summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-09-27 22:16:13 +0200
committerVolpeon <git@volpeon.ink>2022-09-27 22:16:13 +0200
commita8a5abae42f6f42056cc27e0cf5313aab080c3a7 (patch)
tree32c163bbc58aa2f827c5ba5108df81dc14fbe130 /infer.py
parentIncorporate upstream changes (diff)
downloadtextual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.gz
textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.bz2
textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.zip
Various improvements, added inference script
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py121
1 files changed, 121 insertions, 0 deletions
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
1import argparse
2import datetime
3from pathlib import Path
4from torch import autocast
5from diffusers import StableDiffusionPipeline
6import torch
7import json
8from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
9from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
10from slugify import slugify
11from pipelines.stable_diffusion.no_check import NoCheck
12
# NOTE(review): these module-level defaults look unused — main() reads
# args.model and args.prompt from the command line instead; confirm no
# external importer relies on them before removing.
model_id = "path-to-your-trained-model"

prompt = "A photo of sks dog in a bucket"

17
def parse_args():
    """Parse command-line arguments for the inference script.

    If ``--config`` points to a JSON file, the values stored under its
    ``"args"`` key are used as a base namespace and the command line is
    re-parsed on top of them, so explicit CLI flags win over the config.

    Returns:
        argparse.Namespace: the resolved arguments.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Path to the trained model directory.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Text prompt to generate images for.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of images generated per pipeline call.",
    )
    parser.add_argument(
        "--batch_num",
        type=int,
        default=50,
        help="Number of batches to generate.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=80,
        help="Number of denoising steps.",
    )
    parser.add_argument(
        "--scale",
        # Guidance scale is fractional (default 7.5); `type=int` would
        # reject values such as "--scale 7.5" on the command line.
        type=float,
        default=7.5,
        help="Classifier-free guidance scale.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="RNG seed; a random one is drawn when omitted.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="inference",
        help="Directory to write generated images into.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Optional JSON file whose 'args' key supplies defaults.",
    )

    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'rt') as f:
            # Re-parse with the config values as the starting namespace so
            # CLI flags override the file.
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    return args
75
76
def main():
    """Run batched text-to-image inference with a trained SD model.

    Loads the pipeline components from ``args.model`` in bfloat16, builds a
    StableDiffusionPipeline with the safety checker disabled, and writes
    ``batch_num * batch_size`` JPEGs into a per-run output directory.
    Requires a CUDA device.
    """
    args = parse_args()

    # Use `is not None` so an explicit `--seed 0` is honored; the previous
    # `args.seed or ...` form would silently replace 0 with a random seed.
    seed = args.seed if args.seed is not None else torch.random.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # One directory per run: timestamp + seed + slugified prompt (truncated
    # to 80 chars to keep path lengths sane).
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load each component from the trained model directory in bfloat16;
    # the feature extractor comes from the stock CLIP checkpoint.
    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)

    pipeline = StableDiffusionPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        ),
        safety_checker=NoCheck(),  # deliberately disables the NSFW filter
        feature_extractor=feature_extractor
    )
    pipeline.enable_attention_slicing()
    pipeline.to("cuda")

    with autocast("cuda"):
        for i in range(args.batch_num):
            images = pipeline(
                [args.prompt] * args.batch_size,
                num_inference_steps=args.steps,
                guidance_scale=args.scale,
                generator=generator,
            ).images

            # Global running index so filenames never collide across batches.
            for j, image in enumerate(images):
                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
118
119
120if __name__ == "__main__":
121 main()