+#!/usr/bin/env python3
+
+import os
+import sys
+import io
+import warnings
+import argparse
+import json
+from PIL import Image
+from stability_sdk import client
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+
+DOCUMENT_ROOT = "/srv/cworth.org/zombocom"
+IMAGES_PATH = "/images"
+
+def normalize_filename(fn):
+ valid = "-_."
+ out = ""
+ for c in fn:
+ if str.isalpha(c) or str.isdigit(c) or (c in valid):
+ out += c
+ else:
+ out += "_"
+ return out
+
+def save_artifact(artifact, prompt):
+ counter = 1
+ seed = artifact.seed
+ while True:
+ if counter > 1:
+ base = "{}_{}_{}.png".format(seed, prompt, counter)
+ else:
+ base = "{}_{}.png".format(seed, prompt)
+ base = normalize_filename(base)
+ filename = "{}{}/{}".format(DOCUMENT_ROOT, IMAGES_PATH, base)
+ if not os.path.exists(filename):
+ break
+ counter = counter + 1
+
+ img = Image.open(io.BytesIO(artifact.binary))
+ img.save(filename)
+ return {"seed": seed, "prompt": prompt,
+ "filename": filename.removeprefix(DOCUMENT_ROOT)}
+
+def main() -> None:
+ os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
+
+ if not 'STABILITY_KEY' in os.environ:
+ print("Error: STABILITY_KEY must be set with a valid key")
+ sys.exit(1)
+
+ parser = argparse.ArgumentParser(
+ usage="%(prog)s [--seed=SEED] Prompt string here...",
+ description="Generate text-from-image using Stable Diffusion (via dreamstudio)"
+ )
+ parser.add_argument("--seed", nargs='?', type=int, required=False, default=None,
+ help="seed value to use (defaults to random)")
+ parser.add_argument("prompt", nargs='+',
+ help="text-to-image prompt")
+
+ args = parser.parse_args()
+
+ stability_api = client.StabilityInference(
+ key=os.environ['STABILITY_KEY'],
+ verbose=True,
+ engine="stable-diffusion-512-v2-0"
+ )
+
+ prompt_string = " ".join(args.prompt)
+
+ if args.seed:
+ answers = stability_api.generate(seed=args.seed, prompt=prompt_string)
+ else:
+ answers = stability_api.generate(prompt=prompt_string)
+
+ results = []
+ for resp in answers:
+ for artifact in resp.artifacts:
+ if artifact.finish_reason == generation.FILTER:
+ warnings.warn("Request tripped over API safety filters")
+ elif artifact.type == generation.ARTIFACT_IMAGE:
+ results.append(save_artifact(artifact, prompt_string))
+ print(json.dumps(results))
+
+if __name__ == "__main__":
+ main()