]> git.cworth.org Git - zombocom-ai/blob - generate-image.py
Client: Add received images to the DOM
[zombocom-ai] / generate-image.py
1 #!/usr/bin/env python3
2
3 import os
4 import sys
5 import io
6 import warnings
7 import argparse
8 import json
9 from PIL import Image
10 from stability_sdk import client
11 import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
12
# Filesystem directory treated as the document root (presumably the web
# server's); save_artifact strips this prefix from the paths it reports.
DOCUMENT_ROOT = "/srv/cworth.org/zombocom"
# Path under DOCUMENT_ROOT of the directory that generated images are saved to.
IMAGES_PATH = "/images"
15
def normalize_filename(fn):
    """Return *fn* with every disallowed character replaced by "_".

    A character is kept when str.isalpha or str.isdigit accepts it (so
    non-ASCII letters and digits survive too) or when it is one of "-",
    "_" or ".".  Everything else, including "/" and whitespace, becomes an
    underscore, so the result always has the same length as the input.
    """
    keep_punct = "-_."

    def _is_safe(ch):
        # Alphanumerics plus a small set of punctuation are filename-safe.
        return str.isalpha(ch) or str.isdigit(ch) or ch in keep_punct

    return "".join(ch if _is_safe(ch) else "_" for ch in fn)
25
def save_artifact(artifact, prompt, max_name_bytes=255):
    """Save an image artifact under the images directory and describe it.

    The file is named "<seed>_<prompt>.png", or "<seed>_<prompt>_<N>.png"
    (N = 2, 3, ...) when a file with the earlier name already exists.  The
    name is passed through normalize_filename().  The artifact bytes are
    decoded with PIL and written back out as PNG.

    Args:
        artifact: generation artifact; only its ``seed`` and ``binary``
            attributes are read.
        prompt: text prompt the image was generated from.
        max_name_bytes: maximum UTF-8 length of the file name.  It defaults
            to 255, the usual per-component limit (NAME_MAX) on Linux
            filesystems.  The prompt portion of the name is truncated to
            fit, but the returned "prompt" is always the full prompt.

    Returns:
        dict with keys "seed", "prompt" and "filename".  "filename" is the
        saved path with DOCUMENT_ROOT stripped.

    Raises:
        Whatever PIL or the filesystem raise.  If encoding fails, the
        partially written file is removed first.
    """
    seed = artifact.seed
    img = Image.open(io.BytesIO(artifact.binary))

    # normalize_filename maps each character independently, and the seed
    # digits and the "_N.png" suffix pass through it unchanged.  Normalizing
    # the stem on its own is therefore equivalent to normalizing the full name.
    stem = normalize_filename("{}_{}".format(seed, prompt))

    counter = 1
    while True:
        suffix = "_{}.png".format(counter) if counter > 1 else ".png"
        # A long prompt would otherwise exceed NAME_MAX, and the save would
        # fail with "File name too long".  Truncate on a UTF-8 character
        # boundary, because normalize_filename keeps non-ASCII letters.
        budget = max(max_name_bytes - len(suffix), 0)
        short_stem = stem.encode("utf-8")[:budget].decode("utf-8", "ignore")
        filename = "{}{}/{}".format(DOCUMENT_ROOT, IMAGES_PATH,
                                    short_stem + suffix)
        try:
            # Exclusive creation: two concurrent runs can never claim (and
            # overwrite) the same name, and an existing symlink is not
            # followed.
            fp = open(filename, "xb")
        except FileExistsError:
            counter += 1
            continue
        break

    try:
        with fp:
            img.save(fp, format="PNG")
    except Exception:
        # Don't leave an empty/partial file squatting on this name.
        try:
            os.remove(filename)
        except OSError:
            pass
        raise

    return {"seed": seed, "prompt": prompt,
            "filename": filename.removeprefix(DOCUMENT_ROOT)}
44
def main() -> None:
    """Generate images for a text prompt and print their details as JSON.

    Command line: ``[--seed=SEED] Prompt string here...``.  The prompt
    words are joined with single spaces.  STABILITY_KEY must be set in the
    environment.

    Each image artifact in the API response is saved with save_artifact().
    The resulting list of dicts is printed to stdout as a JSON array.
    Artifacts stopped by the API safety filter are reported through the
    warnings module instead.  Error messages go to stderr, so stdout carries
    only the JSON result.
    """
    os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

    parser = argparse.ArgumentParser(
        usage="%(prog)s [--seed=SEED] Prompt string here...",
        description="Generate an image from a text prompt using Stable Diffusion (via dreamstudio)"
    )
    parser.add_argument("--seed", nargs='?', type=int, required=False, default=None,
                        help="seed value to use (defaults to random)")
    parser.add_argument("prompt", nargs='+',
                        help="text-to-image prompt")

    # Parse first so that --help / usage errors work without a key.
    args = parser.parse_args()

    if 'STABILITY_KEY' not in os.environ:
        print("Error: STABILITY_KEY must be set with a valid key",
              file=sys.stderr)
        sys.exit(1)

    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_KEY'],
        verbose=True,
        engine="stable-diffusion-512-v2-0"
    )

    prompt_string = " ".join(args.prompt)

    # Test against None so that an explicit --seed=0 is still forwarded.
    # NOTE(review): the SDK may itself treat seed 0 as "pick randomly";
    # confirm against the stability_sdk version in use.
    if args.seed is not None:
        answers = stability_api.generate(seed=args.seed, prompt=prompt_string)
    else:
        answers = stability_api.generate(prompt=prompt_string)

    results = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn("Request tripped over API safety filters")
            elif artifact.type == generation.ARTIFACT_IMAGE:
                results.append(save_artifact(artifact, prompt_string))
    print(json.dumps(results))
84
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not
    # when this file is imported as a module.
    main()