diff options
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 21 |
1 file changed, 18 insertions, 3 deletions
| @@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True | |||
| 39 | default_args = { | 39 | default_args = { |
| 40 | "model": "stabilityai/stable-diffusion-2-1", | 40 | "model": "stabilityai/stable-diffusion-2-1", |
| 41 | "precision": "fp32", | 41 | "precision": "fp32", |
| 42 | "ti_embeddings_dir": "embeddings", | 42 | "ti_embeddings_dir": "embeddings_ti", |
| 43 | "lora_embeddings_dir": "embeddings_lora", | ||
| 43 | "output_dir": "output/inference", | 44 | "output_dir": "output/inference", |
| 44 | "config": None, | 45 | "config": None, |
| 45 | } | 46 | } |
| @@ -60,6 +61,7 @@ default_cmds = { | |||
| 60 | "batch_num": 1, | 61 | "batch_num": 1, |
| 61 | "steps": 30, | 62 | "steps": 30, |
| 62 | "guidance_scale": 7.0, | 63 | "guidance_scale": 7.0, |
| 64 | "lora_scale": 0.5, | ||
| 63 | "seed": None, | 65 | "seed": None, |
| 64 | "config": None, | 66 | "config": None, |
| 65 | } | 67 | } |
| @@ -92,6 +94,10 @@ def create_args_parser(): | |||
| 92 | type=str, | 94 | type=str, |
| 93 | ) | 95 | ) |
| 94 | parser.add_argument( | 96 | parser.add_argument( |
| 97 | "--lora_embeddings_dir", | ||
| 98 | type=str, | ||
| 99 | ) | ||
| 100 | parser.add_argument( | ||
| 95 | "--output_dir", | 101 | "--output_dir", |
| 96 | type=str, | 102 | type=str, |
| 97 | ) | 103 | ) |
| @@ -169,6 +175,10 @@ def create_cmd_parser(): | |||
| 169 | type=float, | 175 | type=float, |
| 170 | ) | 176 | ) |
| 171 | parser.add_argument( | 177 | parser.add_argument( |
| 178 | "--lora_scale", | ||
| 179 | type=float, | ||
| 180 | ) | ||
| 181 | parser.add_argument( | ||
| 172 | "--seed", | 182 | "--seed", |
| 173 | type=int, | 183 | type=int, |
| 174 | ) | 184 | ) |
| @@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args): | |||
| 315 | generator=generator, | 325 | generator=generator, |
| 316 | image=init_image, | 326 | image=init_image, |
| 317 | strength=args.image_noise, | 327 | strength=args.image_noise, |
| 328 | cross_attention_kwargs={"scale": args.lora_scale}, | ||
| 318 | ).images | 329 | ).images |
| 319 | 330 | ||
| 320 | for j, image in enumerate(images): | 331 | for j, image in enumerate(images): |
| @@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd): | |||
| 334 | prompt = 'dream> ' | 345 | prompt = 'dream> ' |
| 335 | commands = [] | 346 | commands = [] |
| 336 | 347 | ||
| 337 | def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): | 348 | def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): |
| 338 | super().__init__() | 349 | super().__init__() |
| 339 | 350 | ||
| 340 | self.output_dir = output_dir | 351 | self.output_dir = output_dir |
| 341 | self.ti_embeddings_dir = ti_embeddings_dir | 352 | self.ti_embeddings_dir = ti_embeddings_dir |
| 353 | self.lora_embeddings_dir = lora_embeddings_dir | ||
| 342 | self.pipeline = pipeline | 354 | self.pipeline = pipeline |
| 343 | self.parser = parser | 355 | self.parser = parser |
| 344 | 356 | ||
| @@ -394,9 +406,12 @@ def main(): | |||
| 394 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 406 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
| 395 | 407 | ||
| 396 | pipeline = create_pipeline(args.model, dtype) | 408 | pipeline = create_pipeline(args.model, dtype) |
| 409 | |||
| 397 | load_embeddings(pipeline, args.ti_embeddings_dir) | 410 | load_embeddings(pipeline, args.ti_embeddings_dir) |
| 411 | pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | ||
| 412 | |||
| 398 | cmd_parser = create_cmd_parser() | 413 | cmd_parser = create_cmd_parser() |
| 399 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) | 414 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) |
| 400 | cmd_prompt.cmdloop() | 415 | cmd_prompt.cmdloop() |
| 401 | 416 | ||
| 402 | 417 | ||
