summary | refs | log | tree | commit | diff | stats
path: root/infer.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>    2022-12-10 08:43:34 +0100
committer Volpeon <git@volpeon.ink>    2022-12-10 08:43:34 +0100
commit    64c79cc3e7fad49131f90fbb0648b6d5587563e5 (patch)
tree      372bb09a8c952bd28a8da069659da26ce2c99894 /infer.py
parent    Fix sample steps (diff)
download  textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.tar.gz
          textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.tar.bz2
          textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.zip
Various updates; shuffle prompt content during training
Diffstat (limited to 'infer.py')
-rw-r--r--    infer.py | 27
1 file changed, 22 insertions(+), 5 deletions(-)
diff --git a/infer.py b/infer.py
index 30e11cf..e3fa9e5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):