diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 15 |
1 files changed, 12 insertions, 3 deletions
@@ -29,6 +29,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt | |||
29 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
30 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
32 | from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler | ||
32 | from util import load_config, load_embeddings_from_dir | 33 | from util import load_config, load_embeddings_from_dir |
33 | 34 | ||
34 | 35 | ||
@@ -61,6 +62,7 @@ default_cmds = { | |||
61 | "batch_num": 1, | 62 | "batch_num": 1, |
62 | "steps": 30, | 63 | "steps": 30, |
63 | "guidance_scale": 7.0, | 64 | "guidance_scale": 7.0, |
65 | "sag_scale": 0.75, | ||
64 | "lora_scale": 0.5, | 66 | "lora_scale": 0.5, |
65 | "seed": None, | 67 | "seed": None, |
66 | "config": None, | 68 | "config": None, |
@@ -122,7 +124,7 @@ def create_cmd_parser(): | |||
122 | parser.add_argument( | 124 | parser.add_argument( |
123 | "--scheduler", | 125 | "--scheduler", |
124 | type=str, | 126 | type=str, |
125 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | 127 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], |
126 | ) | 128 | ) |
127 | parser.add_argument( | 129 | parser.add_argument( |
128 | "--template", | 130 | "--template", |
@@ -175,6 +177,10 @@ def create_cmd_parser(): | |||
175 | type=float, | 177 | type=float, |
176 | ) | 178 | ) |
177 | parser.add_argument( | 179 | parser.add_argument( |
180 | "--sag_scale", | ||
181 | type=float, | ||
182 | ) | ||
183 | parser.add_argument( | ||
178 | "--lora_scale", | 184 | "--lora_scale", |
179 | type=float, | 185 | type=float, |
180 | ) | 186 | ) |
@@ -304,6 +310,8 @@ def generate(output_dir: Path, pipeline, args): | |||
304 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) | 310 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) |
305 | elif args.scheduler == "kdpm2_a": | 311 | elif args.scheduler == "kdpm2_a": |
306 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | 312 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) |
313 | elif args.scheduler == "unipc": | ||
314 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) | ||
307 | 315 | ||
308 | for i in range(args.batch_num): | 316 | for i in range(args.batch_num): |
309 | pipeline.set_progress_bar_config( | 317 | pipeline.set_progress_bar_config( |
@@ -322,10 +330,11 @@ def generate(output_dir: Path, pipeline, args): | |||
322 | num_images_per_prompt=args.batch_size, | 330 | num_images_per_prompt=args.batch_size, |
323 | num_inference_steps=args.steps, | 331 | num_inference_steps=args.steps, |
324 | guidance_scale=args.guidance_scale, | 332 | guidance_scale=args.guidance_scale, |
333 | sag_scale=args.sag_scale, | ||
325 | generator=generator, | 334 | generator=generator, |
326 | image=init_image, | 335 | image=init_image, |
327 | strength=args.image_noise, | 336 | strength=args.image_noise, |
328 | cross_attention_kwargs={"scale": args.lora_scale}, | 337 | # cross_attention_kwargs={"scale": args.lora_scale}, |
329 | ).images | 338 | ).images |
330 | 339 | ||
331 | for j, image in enumerate(images): | 340 | for j, image in enumerate(images): |
@@ -408,7 +417,7 @@ def main(): | |||
408 | pipeline = create_pipeline(args.model, dtype) | 417 | pipeline = create_pipeline(args.model, dtype) |
409 | 418 | ||
410 | load_embeddings(pipeline, args.ti_embeddings_dir) | 419 | load_embeddings(pipeline, args.ti_embeddings_dir) |
411 | pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 420 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
412 | 421 | ||
413 | cmd_parser = create_cmd_parser() | 422 | cmd_parser = create_cmd_parser() |
414 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) | 423 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) |