diff options
author | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
commit | 7ccd4614a56cfd6ecacba85605f338593f1059f0 (patch) | |
tree | fa9882b256c752705bc42229bac4e00ed7088643 /infer.py | |
parent | Restored LR finder (diff) | |
download | textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.gz textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.bz2 textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.zip |
Add Lora
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 21 |
1 files changed, 18 insertions, 3 deletions
@@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True | |||
39 | default_args = { | 39 | default_args = { |
40 | "model": "stabilityai/stable-diffusion-2-1", | 40 | "model": "stabilityai/stable-diffusion-2-1", |
41 | "precision": "fp32", | 41 | "precision": "fp32", |
42 | "ti_embeddings_dir": "embeddings", | 42 | "ti_embeddings_dir": "embeddings_ti", |
43 | "lora_embeddings_dir": "embeddings_lora", | ||
43 | "output_dir": "output/inference", | 44 | "output_dir": "output/inference", |
44 | "config": None, | 45 | "config": None, |
45 | } | 46 | } |
@@ -60,6 +61,7 @@ default_cmds = { | |||
60 | "batch_num": 1, | 61 | "batch_num": 1, |
61 | "steps": 30, | 62 | "steps": 30, |
62 | "guidance_scale": 7.0, | 63 | "guidance_scale": 7.0, |
64 | "lora_scale": 0.5, | ||
63 | "seed": None, | 65 | "seed": None, |
64 | "config": None, | 66 | "config": None, |
65 | } | 67 | } |
@@ -92,6 +94,10 @@ def create_args_parser(): | |||
92 | type=str, | 94 | type=str, |
93 | ) | 95 | ) |
94 | parser.add_argument( | 96 | parser.add_argument( |
97 | "--lora_embeddings_dir", | ||
98 | type=str, | ||
99 | ) | ||
100 | parser.add_argument( | ||
95 | "--output_dir", | 101 | "--output_dir", |
96 | type=str, | 102 | type=str, |
97 | ) | 103 | ) |
@@ -169,6 +175,10 @@ def create_cmd_parser(): | |||
169 | type=float, | 175 | type=float, |
170 | ) | 176 | ) |
171 | parser.add_argument( | 177 | parser.add_argument( |
178 | "--lora_scale", | ||
179 | type=float, | ||
180 | ) | ||
181 | parser.add_argument( | ||
172 | "--seed", | 182 | "--seed", |
173 | type=int, | 183 | type=int, |
174 | ) | 184 | ) |
@@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args): | |||
315 | generator=generator, | 325 | generator=generator, |
316 | image=init_image, | 326 | image=init_image, |
317 | strength=args.image_noise, | 327 | strength=args.image_noise, |
328 | cross_attention_kwargs={"scale": args.lora_scale}, | ||
318 | ).images | 329 | ).images |
319 | 330 | ||
320 | for j, image in enumerate(images): | 331 | for j, image in enumerate(images): |
@@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd): | |||
334 | prompt = 'dream> ' | 345 | prompt = 'dream> ' |
335 | commands = [] | 346 | commands = [] |
336 | 347 | ||
337 | def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): | 348 | def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): |
338 | super().__init__() | 349 | super().__init__() |
339 | 350 | ||
340 | self.output_dir = output_dir | 351 | self.output_dir = output_dir |
341 | self.ti_embeddings_dir = ti_embeddings_dir | 352 | self.ti_embeddings_dir = ti_embeddings_dir |
353 | self.lora_embeddings_dir = lora_embeddings_dir | ||
342 | self.pipeline = pipeline | 354 | self.pipeline = pipeline |
343 | self.parser = parser | 355 | self.parser = parser |
344 | 356 | ||
@@ -394,9 +406,12 @@ def main(): | |||
394 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 406 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
395 | 407 | ||
396 | pipeline = create_pipeline(args.model, dtype) | 408 | pipeline = create_pipeline(args.model, dtype) |
409 | |||
397 | load_embeddings(pipeline, args.ti_embeddings_dir) | 410 | load_embeddings(pipeline, args.ti_embeddings_dir) |
411 | pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | ||
412 | |||
398 | cmd_parser = create_cmd_parser() | 413 | cmd_parser = create_cmd_parser() |
399 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) | 414 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) |
400 | cmd_prompt.cmdloop() | 415 | cmd_prompt.cmdloop() |
401 | 416 | ||
402 | 417 | ||