From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- infer.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 2b07b21..42b4e2d 100644 --- a/infer.py +++ b/infer.py @@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True default_args = { "model": "stabilityai/stable-diffusion-2-1", "precision": "fp32", - "ti_embeddings_dir": "embeddings", + "ti_embeddings_dir": "embeddings_ti", + "lora_embeddings_dir": "embeddings_lora", "output_dir": "output/inference", "config": None, } @@ -60,6 +61,7 @@ default_cmds = { "batch_num": 1, "steps": 30, "guidance_scale": 7.0, + "lora_scale": 0.5, "seed": None, "config": None, } @@ -91,6 +93,10 @@ def create_args_parser(): "--ti_embeddings_dir", type=str, ) + parser.add_argument( + "--lora_embeddings_dir", + type=str, + ) parser.add_argument( "--output_dir", type=str, @@ -168,6 +174,10 @@ def create_cmd_parser(): "--guidance_scale", type=float, ) + parser.add_argument( + "--lora_scale", + type=float, + ) parser.add_argument( "--seed", type=int, @@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args): generator=generator, image=init_image, strength=args.image_noise, + cross_attention_kwargs={"scale": args.lora_scale}, ).images for j, image in enumerate(images): @@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd): prompt = 'dream> ' commands = [] - def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): + def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): super().__init__() self.output_dir = output_dir self.ti_embeddings_dir = ti_embeddings_dir + self.lora_embeddings_dir = lora_embeddings_dir self.pipeline = pipeline self.parser = parser @@ -394,9 +406,12 @@ def main(): dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] pipeline = create_pipeline(args.model, dtype) + load_embeddings(pipeline, args.ti_embeddings_dir) + pipeline.unet.load_attn_procs(args.lora_embeddings_dir) + cmd_parser = create_cmd_parser() - cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) + cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) cmd_prompt.cmdloop() -- cgit v1.2.3-54-g00ecf