summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-16 09:16:05 +0100
committerVolpeon <git@volpeon.ink>2023-02-16 09:16:05 +0100
commitd673760fc671d665aadae3b032f8e99f21ab986d (patch)
tree7c14a998742b19ddecac6ee25a669892b41c305e /infer.py
parentUpdate (diff)
downloadtextual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.gz
textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.bz2
textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.zip
Integrated WIP UniPC scheduler
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index aa75ee5..329c60b 100644
--- a/infer.py
+++ b/infer.py
@@ -29,6 +29,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
29from models.clip.embeddings import patch_managed_embeddings 29from models.clip.embeddings import patch_managed_embeddings
30from models.clip.tokenizer import MultiCLIPTokenizer 30from models.clip.tokenizer import MultiCLIPTokenizer
31from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 31from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
32from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
32from util import load_config, load_embeddings_from_dir 33from util import load_config, load_embeddings_from_dir
33 34
34 35
@@ -61,6 +62,7 @@ default_cmds = {
61 "batch_num": 1, 62 "batch_num": 1,
62 "steps": 30, 63 "steps": 30,
63 "guidance_scale": 7.0, 64 "guidance_scale": 7.0,
65 "sag_scale": 0.75,
64 "lora_scale": 0.5, 66 "lora_scale": 0.5,
65 "seed": None, 67 "seed": None,
66 "config": None, 68 "config": None,
@@ -122,7 +124,7 @@ def create_cmd_parser():
122 parser.add_argument( 124 parser.add_argument(
123 "--scheduler", 125 "--scheduler",
124 type=str, 126 type=str,
125 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], 127 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"],
126 ) 128 )
127 parser.add_argument( 129 parser.add_argument(
128 "--template", 130 "--template",
@@ -175,6 +177,10 @@ def create_cmd_parser():
175 type=float, 177 type=float,
176 ) 178 )
177 parser.add_argument( 179 parser.add_argument(
180 "--sag_scale",
181 type=float,
182 )
183 parser.add_argument(
178 "--lora_scale", 184 "--lora_scale",
179 type=float, 185 type=float,
180 ) 186 )
@@ -304,6 +310,8 @@ def generate(output_dir: Path, pipeline, args):
304 pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) 310 pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
305 elif args.scheduler == "kdpm2_a": 311 elif args.scheduler == "kdpm2_a":
306 pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) 312 pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
313 elif args.scheduler == "unipc":
314 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
307 315
308 for i in range(args.batch_num): 316 for i in range(args.batch_num):
309 pipeline.set_progress_bar_config( 317 pipeline.set_progress_bar_config(
@@ -322,10 +330,11 @@ def generate(output_dir: Path, pipeline, args):
322 num_images_per_prompt=args.batch_size, 330 num_images_per_prompt=args.batch_size,
323 num_inference_steps=args.steps, 331 num_inference_steps=args.steps,
324 guidance_scale=args.guidance_scale, 332 guidance_scale=args.guidance_scale,
333 sag_scale=args.sag_scale,
325 generator=generator, 334 generator=generator,
326 image=init_image, 335 image=init_image,
327 strength=args.image_noise, 336 strength=args.image_noise,
328 cross_attention_kwargs={"scale": args.lora_scale}, 337 # cross_attention_kwargs={"scale": args.lora_scale},
329 ).images 338 ).images
330 339
331 for j, image in enumerate(images): 340 for j, image in enumerate(images):
@@ -408,7 +417,7 @@ def main():
408 pipeline = create_pipeline(args.model, dtype) 417 pipeline = create_pipeline(args.model, dtype)
409 418
410 load_embeddings(pipeline, args.ti_embeddings_dir) 419 load_embeddings(pipeline, args.ti_embeddings_dir)
411 pipeline.unet.load_attn_procs(args.lora_embeddings_dir) 420 # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
412 421
413 cmd_parser = create_cmd_parser() 422 cmd_parser = create_cmd_parser()
414 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) 423 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)