summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/infer.py b/infer.py
index 3b3b595..0a219a5 100644
--- a/infer.py
+++ b/infer.py
@@ -46,7 +46,7 @@ default_args = {
46 "model": "stabilityai/stable-diffusion-2-1", 46 "model": "stabilityai/stable-diffusion-2-1",
47 "precision": "fp32", 47 "precision": "fp32",
48 "ti_embeddings_dir": "embeddings_ti", 48 "ti_embeddings_dir": "embeddings_ti",
49 "lora_embedding": None, 49 "lora_embeddings_dir": None,
50 "output_dir": "output/inference", 50 "output_dir": "output/inference",
51 "config": None, 51 "config": None,
52} 52}
@@ -99,7 +99,7 @@ def create_args_parser():
99 type=str, 99 type=str,
100 ) 100 )
101 parser.add_argument( 101 parser.add_argument(
102 "--lora_embedding", 102 "--lora_embeddings_dir",
103 type=str, 103 type=str,
104 ) 104 )
105 parser.add_argument( 105 parser.add_argument(
@@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
341 341
342 342
343def create_pipeline(model, dtype): 343def create_pipeline(model, dtype):
344 print("Loading Stable Diffusion pipeline...") 344 print(f"Loading Stable Diffusion pipeline: {model}...")
345 345
346 tokenizer = MultiCLIPTokenizer.from_pretrained( 346 tokenizer = MultiCLIPTokenizer.from_pretrained(
347 model, subfolder="tokenizer", torch_dtype=dtype 347 model, subfolder="tokenizer", torch_dtype=dtype
@@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args):
435 negative_prompt=args.negative_prompt, 435 negative_prompt=args.negative_prompt,
436 height=args.height, 436 height=args.height,
437 width=args.width, 437 width=args.width,
438 generator=generator,
439 guidance_scale=args.guidance_scale,
438 num_images_per_prompt=args.batch_size, 440 num_images_per_prompt=args.batch_size,
439 num_inference_steps=args.steps, 441 num_inference_steps=args.steps,
440 guidance_scale=args.guidance_scale,
441 sag_scale=args.sag_scale, 442 sag_scale=args.sag_scale,
442 generator=generator,
443 image=init_image, 443 image=init_image,
444 strength=args.image_noise, 444 strength=args.image_noise,
445 ).images 445 ).images
@@ -527,8 +527,8 @@ def main():
527 527
528 pipeline = create_pipeline(args.model, dtype) 528 pipeline = create_pipeline(args.model, dtype)
529 529
530 load_embeddings_dir(pipeline, args.ti_embeddings_dir) 530 # load_embeddings_dir(pipeline, args.ti_embeddings_dir)
531 load_lora(pipeline, args.lora_embedding) 531 # load_lora(pipeline, args.lora_embeddings_dir)
532 # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) 532 # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
533 533
534 cmd_parser = create_cmd_parser() 534 cmd_parser = create_cmd_parser()