summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-01 11:40:14 +0200
committerVolpeon <git@volpeon.ink>2022-10-01 11:40:14 +0200
commit5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch)
treea3461a4f1a04fba52ec8fde8b7b07095c7422d85 /infer.py
parentAdded custom SD pipeline + euler_a scheduler (diff)
downloadtextual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.gz
textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.bz2
textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.zip
Made inference script interactive
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py154
1 files changed, 109 insertions, 45 deletions
diff --git a/infer.py b/infer.py
index de3d792..40720ea 100644
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,21 @@
1import argparse 1import argparse
2import datetime 2import datetime
3import logging 3import logging
4import sys
5import shlex
6import cmd
4from pathlib import Path 7from pathlib import Path
5from torch import autocast 8from torch import autocast
6import torch 9import torch
7import json 10import json
8from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler 11from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
9from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor 12from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
10from slugify import slugify 13from slugify import slugify
11from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion 14from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
12from schedulers.scheduling_euler_a import EulerAScheduler 15from schedulers.scheduling_euler_a import EulerAScheduler
13 16
14 17
15def parse_args(): 18def create_args_parser():
16 parser = argparse.ArgumentParser( 19 parser = argparse.ArgumentParser(
17 description="Simple example of a training script." 20 description="Simple example of a training script."
18 ) 21 )
@@ -22,6 +25,30 @@ def parse_args():
22 default=None, 25 default=None,
23 ) 26 )
24 parser.add_argument( 27 parser.add_argument(
28 "--scheduler",
29 type=str,
30 choices=["plms", "ddim", "klms", "euler_a"],
31 default="euler_a",
32 )
33 parser.add_argument(
34 "--output_dir",
35 type=str,
36 default="output/inference",
37 )
38 parser.add_argument(
39 "--config",
40 type=str,
41 default=None,
42 )
43
44 return parser
45
46
47def create_cmd_parser():
48 parser = argparse.ArgumentParser(
49 description="Simple example of a training script."
50 )
51 parser.add_argument(
25 "--prompt", 52 "--prompt",
26 type=str, 53 type=str,
27 default=None, 54 default=None,
@@ -49,28 +76,17 @@ def parse_args():
49 parser.add_argument( 76 parser.add_argument(
50 "--batch_num", 77 "--batch_num",
51 type=int, 78 type=int,
52 default=50, 79 default=1,
53 ) 80 )
54 parser.add_argument( 81 parser.add_argument(
55 "--steps", 82 "--steps",
56 type=int, 83 type=int,
57 default=120, 84 default=70,
58 )
59 parser.add_argument(
60 "--scheduler",
61 type=str,
62 choices=["plms", "ddim", "klms", "euler_a"],
63 default="euler_a",
64 ) 85 )
65 parser.add_argument( 86 parser.add_argument(
66 "--guidance_scale", 87 "--guidance_scale",
67 type=int, 88 type=int,
68 default=7.5, 89 default=7,
69 )
70 parser.add_argument(
71 "--clip_guidance_scale",
72 type=int,
73 default=100,
74 ) 90 )
75 parser.add_argument( 91 parser.add_argument(
76 "--seed", 92 "--seed",
@@ -78,21 +94,21 @@ def parse_args():
78 default=torch.random.seed(), 94 default=torch.random.seed(),
79 ) 95 )
80 parser.add_argument( 96 parser.add_argument(
81 "--output_dir",
82 type=str,
83 default="output/inference",
84 )
85 parser.add_argument(
86 "--config", 97 "--config",
87 type=str, 98 type=str,
88 default=None, 99 default=None,
89 ) 100 )
90 101
91 args = parser.parse_args() 102 return parser
103
104
105def run_parser(parser, input=None):
106 args = parser.parse_known_args(input)[0]
107
92 if args.config is not None: 108 if args.config is not None:
93 with open(args.config, 'rt') as f: 109 with open(args.config, 'rt') as f:
94 args = parser.parse_args( 110 args = parser.parse_known_args(
95 namespace=argparse.Namespace(**json.load(f)["args"])) 111 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
96 112
97 return args 113 return args
98 114
@@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}):
104 json.dump(info, f, indent=4) 120 json.dump(info, f, indent=4)
105 121
106 122
107def gen(args, output_dir): 123def create_pipeline(model, scheduler, dtype=torch.bfloat16):
108 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) 124 print("Loading Stable Diffusion pipeline...")
109 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
110 clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
111 vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
112 unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
113 feature_extractor = CLIPFeatureExtractor.from_pretrained(
114 "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
115 125
116 if args.scheduler == "plms": 126 tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
127 text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
128 vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
129 unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
130 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
131
132 if scheduler == "plms":
117 scheduler = PNDMScheduler( 133 scheduler = PNDMScheduler(
118 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 134 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
119 ) 135 )
120 elif args.scheduler == "klms": 136 elif scheduler == "klms":
121 scheduler = LMSDiscreteScheduler( 137 scheduler = LMSDiscreteScheduler(
122 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 138 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
123 ) 139 )
124 elif args.scheduler == "ddim": 140 elif scheduler == "ddim":
125 scheduler = DDIMScheduler( 141 scheduler = DDIMScheduler(
126 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False 142 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
127 ) 143 )
@@ -135,13 +151,24 @@ def gen(args, output_dir):
135 vae=vae, 151 vae=vae,
136 unet=unet, 152 unet=unet,
137 tokenizer=tokenizer, 153 tokenizer=tokenizer,
138 clip_model=clip_model,
139 scheduler=scheduler, 154 scheduler=scheduler,
140 feature_extractor=feature_extractor 155 feature_extractor=feature_extractor
141 ) 156 )
142 pipeline.enable_attention_slicing() 157 pipeline.enable_attention_slicing()
143 pipeline.to("cuda") 158 pipeline.to("cuda")
144 159
160 print("Pipeline loaded.")
161
162 return pipeline
163
164
165def generate(output_dir, pipeline, args):
166 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
167 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
168 output_dir.mkdir(parents=True, exist_ok=True)
169
170 save_args(output_dir, args)
171
145 with autocast("cuda"): 172 with autocast("cuda"):
146 for i in range(args.batch_num): 173 for i in range(args.batch_num):
147 generator = torch.Generator(device="cuda").manual_seed(args.seed + i) 174 generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
@@ -152,7 +179,6 @@ def gen(args, output_dir):
152 negative_prompt=args.negative_prompt, 179 negative_prompt=args.negative_prompt,
153 num_inference_steps=args.steps, 180 num_inference_steps=args.steps,
154 guidance_scale=args.guidance_scale, 181 guidance_scale=args.guidance_scale,
155 clip_guidance_scale=args.clip_guidance_scale,
156 generator=generator, 182 generator=generator,
157 ).images 183 ).images
158 184
@@ -160,18 +186,56 @@ def gen(args, output_dir):
160 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) 186 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
161 187
162 188
163def main(): 189class CmdParse(cmd.Cmd):
164 args = parse_args() 190 prompt = 'dream> '
191 commands = []
165 192
166 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 193 def __init__(self, output_dir, pipeline, parser):
167 output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") 194 super().__init__()
168 output_dir.mkdir(parents=True, exist_ok=True)
169 195
170 save_args(output_dir, args) 196 self.output_dir = output_dir
197 self.pipeline = pipeline
198 self.parser = parser
199
200 def default(self, line):
201 line = line.replace("'", "\\'")
202
203 try:
204 elements = shlex.split(line)
205 except ValueError as e:
206 print(str(e))
207
208 if elements[0] == 'q':
209 return True
210
211 try:
212 args = run_parser(self.parser, elements)
213 except SystemExit:
214 self.parser.print_help()
215
216 if len(args.prompt) == 0:
217 print('Try again with a prompt!')
218
219 try:
220 generate(self.output_dir, self.pipeline, args)
221 except KeyboardInterrupt:
222 print('Generation cancelled.')
223
224 def do_exit(self, line):
225 return True
226
227
228def main():
229 logging.basicConfig(stream=sys.stdout, level=logging.WARN)
171 230
172 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 231 args_parser = create_args_parser()
232 args = run_parser(args_parser)
233 output_dir = Path(args.output_dir)
173 234
174 gen(args, output_dir) 235 pipeline = create_pipeline(args.model, args.scheduler)
236 cmd_parser = create_cmd_parser()
237 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
238 cmd_prompt.cmdloop()
175 239
176 240
177if __name__ == "__main__": 241if __name__ == "__main__":