]> git.cworth.org Git - zombocom-ai/commitdiff
Add generate_image.py script
authorCarl Worth <cworth@cworth.org>
Wed, 7 Dec 2022 07:55:29 +0000 (23:55 -0800)
committerCarl Worth <cworth@cworth.org>
Wed, 7 Dec 2022 09:24:55 +0000 (01:24 -0800)
Which uses the dreamstudio interface to Stable Diffusion to generate
images from a text prompt. This does require an API key and has non-zero
cost, but I'm willing to do this for now, (and fortunately, the Stable
Diffusion code itself is entirely open-source so I could switch to my
own hardware when I decide to).

generate-image.py [new file with mode: 0755]

diff --git a/generate-image.py b/generate-image.py
new file mode 100755 (executable)
index 0000000..229428f
--- /dev/null
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import io
+import warnings
+import argparse
+import json
+from PIL import Image
+from stability_sdk import client
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+
+DOCUMENT_ROOT = "/srv/cworth.org/zombocom"
+IMAGES_PATH = "/images"
+
+def normalize_filename(fn):
+    valid = "-_."
+    out = ""
+    for c in fn:
+      if str.isalpha(c) or str.isdigit(c) or (c in valid):
+        out += c
+      else:
+        out += "_"
+    return out
+
+def save_artifact(artifact, prompt):
+    counter = 1
+    seed = artifact.seed
+    while True:
+        if counter > 1:
+            base = "{}_{}_{}.png".format(seed, prompt, counter)
+        else:
+            base = "{}_{}.png".format(seed, prompt)
+        base = normalize_filename(base)
+        filename = "{}{}/{}".format(DOCUMENT_ROOT, IMAGES_PATH, base)
+        if not os.path.exists(filename):
+            break
+        counter = counter + 1
+
+    img = Image.open(io.BytesIO(artifact.binary))
+    img.save(filename)
+    return {"seed": seed, "prompt": prompt,
+            "filename": filename.removeprefix(DOCUMENT_ROOT)}
+
+def main() -> None:
+    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
+
+    if not 'STABILITY_KEY' in os.environ:
+        print("Error: STABILITY_KEY must be set with a valid key")
+        sys.exit(1)
+
+    parser = argparse.ArgumentParser(
+        usage="%(prog)s [--seed=SEED] Prompt string here...",
+        description="Generate text-from-image using Stable Diffusion (via dreamstudio)"
+    )
+    parser.add_argument("--seed", nargs='?', type=int, required=False, default=None,
+                        help="seed value to use (defaults to random)")
+    parser.add_argument("prompt", nargs='+',
+                        help="text-to-image prompt")
+
+    args = parser.parse_args()
+
+    stability_api = client.StabilityInference(
+        key=os.environ['STABILITY_KEY'],
+        verbose=True,
+        engine="stable-diffusion-512-v2-0"
+    )
+
+    prompt_string = " ".join(args.prompt)
+
+    if args.seed:
+        answers = stability_api.generate(seed=args.seed, prompt=prompt_string)
+    else:
+        answers = stability_api.generate(prompt=prompt_string)
+
+    results = []
+    for resp in answers:
+        for artifact in resp.artifacts:
+            if artifact.finish_reason == generation.FILTER:
+                warnings.warn("Request tripped over API safety filters")
+            elif artifact.type == generation.ARTIFACT_IMAGE:
+                results.append(save_artifact(artifact, prompt_string))
+    print(json.dumps(results))
+
+if __name__ == "__main__":
+    main()