From d673760fc671d665aadae3b032f8e99f21ab986d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 16 Feb 2023 09:16:05 +0100 Subject: Integrated WIP UniPC scheduler --- infer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index aa75ee5..329c60b 100644 --- a/infer.py +++ b/infer.py @@ -29,6 +29,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt from models.clip.embeddings import patch_managed_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from util import load_config, load_embeddings_from_dir @@ -61,6 +62,7 @@ default_cmds = { "batch_num": 1, "steps": 30, "guidance_scale": 7.0, + "sag_scale": 0.75, "lora_scale": 0.5, "seed": None, "config": None, @@ -122,7 +124,7 @@ def create_cmd_parser(): parser.add_argument( "--scheduler", type=str, - choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], + choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], ) parser.add_argument( "--template", @@ -174,6 +176,10 @@ def create_cmd_parser(): "--guidance_scale", type=float, ) + parser.add_argument( + "--sag_scale", + type=float, + ) parser.add_argument( "--lora_scale", type=float, @@ -304,6 +310,8 @@ def generate(output_dir: Path, pipeline, args): pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) elif args.scheduler == "kdpm2_a": pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) + elif args.scheduler == "unipc": + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) for i in range(args.batch_num): pipeline.set_progress_bar_config( @@ -322,10 +330,11 @@ def generate(output_dir: Path, pipeline, args): num_images_per_prompt=args.batch_size, num_inference_steps=args.steps, guidance_scale=args.guidance_scale, + sag_scale=args.sag_scale, generator=generator, image=init_image, strength=args.image_noise, - cross_attention_kwargs={"scale": args.lora_scale}, + # cross_attention_kwargs={"scale": args.lora_scale}, ).images for j, image in enumerate(images): @@ -408,7 +417,7 @@ def main(): pipeline = create_pipeline(args.model, dtype) load_embeddings(pipeline, args.ti_embeddings_dir) - pipeline.unet.load_attn_procs(args.lora_embeddings_dir) + # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) cmd_parser = create_cmd_parser() cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) -- cgit v1.2.3-54-g00ecf