diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 14 |
1 files changed, 7 insertions, 7 deletions
@@ -46,7 +46,7 @@ default_args = { | |||
46 | "model": "stabilityai/stable-diffusion-2-1", | 46 | "model": "stabilityai/stable-diffusion-2-1", |
47 | "precision": "fp32", | 47 | "precision": "fp32", |
48 | "ti_embeddings_dir": "embeddings_ti", | 48 | "ti_embeddings_dir": "embeddings_ti", |
49 | "lora_embedding": None, | 49 | "lora_embeddings_dir": None, |
50 | "output_dir": "output/inference", | 50 | "output_dir": "output/inference", |
51 | "config": None, | 51 | "config": None, |
52 | } | 52 | } |
@@ -99,7 +99,7 @@ def create_args_parser(): | |||
99 | type=str, | 99 | type=str, |
100 | ) | 100 | ) |
101 | parser.add_argument( | 101 | parser.add_argument( |
102 | "--lora_embedding", | 102 | "--lora_embeddings_dir", |
103 | type=str, | 103 | type=str, |
104 | ) | 104 | ) |
105 | parser.add_argument( | 105 | parser.add_argument( |
@@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) | |||
341 | 341 | ||
342 | 342 | ||
343 | def create_pipeline(model, dtype): | 343 | def create_pipeline(model, dtype): |
344 | print("Loading Stable Diffusion pipeline...") | 344 | print(f"Loading Stable Diffusion pipeline: {model}...") |
345 | 345 | ||
346 | tokenizer = MultiCLIPTokenizer.from_pretrained( | 346 | tokenizer = MultiCLIPTokenizer.from_pretrained( |
347 | model, subfolder="tokenizer", torch_dtype=dtype | 347 | model, subfolder="tokenizer", torch_dtype=dtype |
@@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args): | |||
435 | negative_prompt=args.negative_prompt, | 435 | negative_prompt=args.negative_prompt, |
436 | height=args.height, | 436 | height=args.height, |
437 | width=args.width, | 437 | width=args.width, |
438 | generator=generator, | ||
439 | guidance_scale=args.guidance_scale, | ||
438 | num_images_per_prompt=args.batch_size, | 440 | num_images_per_prompt=args.batch_size, |
439 | num_inference_steps=args.steps, | 441 | num_inference_steps=args.steps, |
440 | guidance_scale=args.guidance_scale, | ||
441 | sag_scale=args.sag_scale, | 442 | sag_scale=args.sag_scale, |
442 | generator=generator, | ||
443 | image=init_image, | 443 | image=init_image, |
444 | strength=args.image_noise, | 444 | strength=args.image_noise, |
445 | ).images | 445 | ).images |
@@ -527,8 +527,8 @@ def main(): | |||
527 | 527 | ||
528 | pipeline = create_pipeline(args.model, dtype) | 528 | pipeline = create_pipeline(args.model, dtype) |
529 | 529 | ||
530 | load_embeddings_dir(pipeline, args.ti_embeddings_dir) | 530 | # load_embeddings_dir(pipeline, args.ti_embeddings_dir) |
531 | load_lora(pipeline, args.lora_embedding) | 531 | # load_lora(pipeline, args.lora_embeddings_dir) |
532 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 532 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
533 | 533 | ||
534 | cmd_parser = create_cmd_parser() | 534 | cmd_parser = create_cmd_parser() |