]> git.cworth.org Git - zombocom-ai/blob - generate-image.py
Reword Coda's final message
[zombocom-ai] / generate-image.py
1 #!/usr/bin/env python3
2
3 import os
4 import sys
5 import io
6 import warnings
7 import argparse
8 import json
9 from PIL import Image
10 from stability_sdk import client
11 import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
12
# Filesystem directory under which generated images are written; it is
# stripped from the "filename" reported back in save_artifact's result
# (presumably the web server's document root — confirm with deployment).
DOCUMENT_ROOT = "/srv/cworth.org/zombocom"
# Path appended to DOCUMENT_ROOT to form the directory images are saved in;
# it remains as the leading part of the reported "filename".
IMAGES_PATH = "/images"
15
def normalize_filename(fn):
    """Return *fn* with every character not safe for a filename replaced by "_".

    Characters accepted by ``str.isalpha`` or ``str.isdigit`` are kept, and
    so are the three characters ``-``, ``_`` and ``.``. Because those are
    Unicode-aware predicates, non-ASCII letters and digits are kept too.
    Every other character becomes ``_``, including ``/`` and whitespace.
    The result therefore has the same length as *fn*, and it can never
    contain a path separator.
    """
    valid = "-_."
    # A single join instead of repeated ``+=`` avoids quadratic string
    # building on long prompts.
    return "".join(
        c if c.isalpha() or c.isdigit() or c in valid else "_"
        for c in fn
    )
25
def save_artifact(artifact, prompt):
    """Decode one generated image artifact and save it as a PNG file.

    The file is written to DOCUMENT_ROOT + IMAGES_PATH. Its name is
    ``<seed>_<first 200 chars of prompt>.png``, passed through
    normalize_filename. If that name is taken, ``_2``, ``_3``, ... is
    inserted before the extension until a free name is found.

    Args:
        artifact: object exposing ``seed`` and ``binary`` (the encoded image
            bytes handed to PIL's ``Image.open``).
        prompt: the full prompt string used for generation.

    Returns:
        dict with keys "code" (the artifact's seed), "prompt" (the full
        prompt) and "filename" (the saved path with DOCUMENT_ROOT removed).
    """
    seed = artifact.seed
    prompt_abbrev = prompt[:200]  # loop-invariant; hoisted out of the loop

    # Decode before touching the filesystem so that a corrupt payload raises
    # here and does not leave an empty file behind.
    with Image.open(io.BytesIO(artifact.binary)) as img:
        img.load()

        counter = 1
        while True:
            if counter > 1:
                base = "{}_{}_{}.png".format(seed, prompt_abbrev, counter)
            else:
                base = "{}_{}.png".format(seed, prompt_abbrev)
            filename = "{}{}/{}".format(DOCUMENT_ROOT, IMAGES_PATH,
                                        normalize_filename(base))
            try:
                # Mode "x" fails if the name already exists. Choosing a free
                # name and claiming it is therefore one atomic step, and a
                # concurrent run cannot pick the same file and overwrite it.
                out = open(filename, "xb")
            except FileExistsError:
                counter += 1
                continue
            with out:
                img.save(out, format="PNG")
            break

    return {"code": seed, "prompt": prompt,
            "filename": filename.removeprefix(DOCUMENT_ROOT)}
45
def main() -> None:
    """Generate images for a command-line prompt and print the results as JSON.

    Usage: ``generate-image.py [--seed=SEED] prompt words...``. The prompt
    words are joined with spaces. 768x768 images are requested from the
    ``stable-diffusion-768-v2-1`` engine, and every returned image artifact
    is saved via save_artifact. The list of result records is printed as a
    JSON array on stdout. A warning is emitted for any artifact whose
    finish reason is the safety filter.

    Exits with status 1 if the STABILITY_KEY environment variable is unset.
    """
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

    parser = argparse.ArgumentParser(
        usage="%(prog)s [--seed=SEED] Prompt string here...",
        description="Generate image-from-text using Stable Diffusion (via dreamstudio)"
    )
    parser.add_argument("--seed", nargs='?', type=int, required=False, default=None,
                        help="seed value to use (defaults to random)")
    parser.add_argument("prompt", nargs='+',
                        help="text-to-image prompt")

    # Parse arguments before checking credentials so that --help (and usage
    # errors) work even when no API key is configured.
    args = parser.parse_args()

    if 'STABILITY_KEY' not in os.environ:
        print("Error: STABILITY_KEY must be set with a valid key")
        sys.exit(1)

    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_KEY'],
        verbose=True,
        engine="stable-diffusion-768-v2-1"
    )

    prompt_string = " ".join(args.prompt)

    generate_kwargs = {"width": 768, "height": 768, "prompt": prompt_string}
    # Compare with None rather than truthiness. An explicit --seed=0 is still
    # the user's choice and is forwarded; how the service treats 0 is up to it.
    if args.seed is not None:
        generate_kwargs["seed"] = args.seed
    answers = stability_api.generate(**generate_kwargs)

    results = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn("Request tripped over API safety filters")
            elif artifact.type == generation.ARTIFACT_IMAGE:
                results.append(save_artifact(artifact, prompt_string))
    print(json.dumps(results))
87
if __name__ == "__main__":
    # Run as a script. main() returns None, so sys.exit(None) gives exit
    # status 0, just as falling off the end of the script would.
    sys.exit(main())