summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py56
1 files changed, 44 insertions, 12 deletions
diff --git a/infer.py b/infer.py
index f88245a..c4d1e0d 100644
--- a/infer.py
+++ b/infer.py
@@ -45,6 +45,7 @@ default_args = {
45 45
46 46
47default_cmds = { 47default_cmds = {
48 "project": "",
48 "scheduler": "dpmsm", 49 "scheduler": "dpmsm",
49 "prompt": None, 50 "prompt": None,
50 "negative_prompt": None, 51 "negative_prompt": None,
@@ -104,6 +105,12 @@ def create_cmd_parser():
104 description="Simple example of a training script." 105 description="Simple example of a training script."
105 ) 106 )
106 parser.add_argument( 107 parser.add_argument(
108 "--project",
109 type=str,
110 default=None,
111 help="The name of the current project.",
112 )
113 parser.add_argument(
107 "--scheduler", 114 "--scheduler",
108 type=str, 115 type=str,
109 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], 116 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
@@ -184,7 +191,16 @@ def save_args(basepath, args, extra={}):
184 json.dump(info, f, indent=4) 191 json.dump(info, f, indent=4)
185 192
186 193
187def create_pipeline(model, embeddings_dir, dtype): 194def load_embeddings(pipeline, embeddings_dir):
195 added_tokens = load_embeddings_from_dir(
196 pipeline.tokenizer,
197 pipeline.text_encoder.text_model.embeddings,
198 Path(embeddings_dir)
199 )
200 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
201
202
203def create_pipeline(model, dtype):
188 print("Loading Stable Diffusion pipeline...") 204 print("Loading Stable Diffusion pipeline...")
189 205
190 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 206 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -193,10 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
193 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 209 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
194 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 210 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
195 211
196 embeddings = patch_managed_embeddings(text_encoder) 212 patch_managed_embeddings(text_encoder)
197 added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
198
199 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
200 213
201 pipeline = VlpnStableDiffusion( 214 pipeline = VlpnStableDiffusion(
202 text_encoder=text_encoder, 215 text_encoder=text_encoder,
@@ -220,7 +233,14 @@ def generate(output_dir, pipeline, args):
220 args.prompt = [args.prompt] 233 args.prompt = [args.prompt]
221 234
222 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 235 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
223 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") 236 use_subdirs = len(args.prompt) != 1
237 if use_subdirs:
238 if len(args.project) != 0:
239 output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
240 else:
241 output_dir = output_dir.joinpath(now)
242 else:
243 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
224 output_dir.mkdir(parents=True, exist_ok=True) 244 output_dir.mkdir(parents=True, exist_ok=True)
225 245
226 args.seed = args.seed or torch.random.seed() 246 args.seed = args.seed or torch.random.seed()
@@ -257,7 +277,8 @@ def generate(output_dir, pipeline, args):
257 dynamic_ncols=True 277 dynamic_ncols=True
258 ) 278 )
259 279
260 generator = torch.Generator(device="cuda").manual_seed(args.seed + i) 280 seed = args.seed + i
281 generator = torch.Generator(device="cuda").manual_seed(seed)
261 images = pipeline( 282 images = pipeline(
262 prompt=args.prompt, 283 prompt=args.prompt,
263 negative_prompt=args.negative_prompt, 284 negative_prompt=args.negative_prompt,
@@ -272,8 +293,13 @@ def generate(output_dir, pipeline, args):
272 ).images 293 ).images
273 294
274 for j, image in enumerate(images): 295 for j, image in enumerate(images):
275 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) 296 image_dir = output_dir
276 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) 297 if use_subdirs:
298 idx = j % len(args.prompt)
299 image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100])
300 image_dir.mkdir(parents=True, exist_ok=True)
301 image.save(image_dir.joinpath(f"{seed}_{j}.png"))
302 image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85)
277 303
278 if torch.cuda.is_available(): 304 if torch.cuda.is_available():
279 torch.cuda.empty_cache() 305 torch.cuda.empty_cache()
@@ -283,10 +309,11 @@ class CmdParse(cmd.Cmd):
283 prompt = 'dream> ' 309 prompt = 'dream> '
284 commands = [] 310 commands = []
285 311
286 def __init__(self, output_dir, pipeline, parser): 312 def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
287 super().__init__() 313 super().__init__()
288 314
289 self.output_dir = output_dir 315 self.output_dir = output_dir
316 self.ti_embeddings_dir = ti_embeddings_dir
290 self.pipeline = pipeline 317 self.pipeline = pipeline
291 self.parser = parser 318 self.parser = parser
292 319
@@ -302,6 +329,10 @@ class CmdParse(cmd.Cmd):
302 if elements[0] == 'q': 329 if elements[0] == 'q':
303 return True 330 return True
304 331
332 if elements[0] == 'reload_embeddings':
333 load_embeddings(self.pipeline, self.ti_embeddings_dir)
334 return
335
305 try: 336 try:
306 args = run_parser(self.parser, default_cmds, elements) 337 args = run_parser(self.parser, default_cmds, elements)
307 338
@@ -337,9 +368,10 @@ def main():
337 output_dir = Path(args.output_dir) 368 output_dir = Path(args.output_dir)
338 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 369 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
339 370
340 pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype) 371 pipeline = create_pipeline(args.model, dtype)
372 load_embeddings(pipeline, args.ti_embeddings_dir)
341 cmd_parser = create_cmd_parser() 373 cmd_parser = create_cmd_parser()
342 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 374 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
343 cmd_prompt.cmdloop() 375 cmd_prompt.cmdloop()
344 376
345 377