| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-12-10 08:43:34 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2022-12-10 08:43:34 +0100 |
| commit | 64c79cc3e7fad49131f90fbb0648b6d5587563e5 (patch) | |
| tree | 372bb09a8c952bd28a8da069659da26ce2c99894 /infer.py | |
| parent | Fix sample steps (diff) | |
Various updates; shuffle prompt content during training
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 27 |
1 file changed, 22 insertions(+), 5 deletions(-)
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
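As a reading aid (not part of the commit), the scheduler branches added above can be summarized as a name-to-class lookup. This is a minimal sketch assuming the same `pipeline.scheduler.config` swap that `generate()` performs; `set_scheduler` is a hypothetical helper name, and the `plms` choice (served by the imported `PNDMScheduler`) is omitted because its branch is not visible in this hunk.

```python
# Hypothetical summary of the scheduler selection in this commit.
# Not part of infer.py; it mirrors the if/elif chain in generate().
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LMSDiscreteScheduler,
)

# CLI name -> diffusers scheduler class (branches visible in the hunk above).
SCHEDULERS = {
    "klms": LMSDiscreteScheduler,
    "ddim": DDIMScheduler,
    "dpmsm": DPMSolverMultistepScheduler,
    "dpmss": DPMSolverSinglestepScheduler,
    "euler_a": EulerAncestralDiscreteScheduler,
    "kdpm2": KDPM2DiscreteScheduler,
    "kdpm2_a": KDPM2AncestralDiscreteScheduler,
}


def set_scheduler(pipeline, name):
    """Swap the pipeline's scheduler in place, reusing its existing config."""
    cls = SCHEDULERS.get(name)
    if cls is not None:
        pipeline.scheduler = cls.from_config(pipeline.scheduler.config)
    return pipeline
```

Under the new defaults (`"scheduler": "dpmsm"`, `"steps": 30`), runs use the multistep DPM-Solver at 30 steps unless overridden on the command line.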
