author     Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
commit     49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch)
tree       3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /infer.py
parent     Fix seed, better progress bar, fix euler_a for batch size > 1 (diff)
WIP: img2img
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py | 90
1 file changed, 67 insertions(+), 23 deletions(-)
diff --git a/infer.py b/infer.py
index d917239..b440cb6 100644
--- a/infer.py
+++ b/infer.py
@@ -8,13 +8,47 @@ from pathlib import Path
 from torch import autocast
 import torch
 import json
+from PIL import Image
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+default_args = {
+    "model": None,
+    "scheduler": "euler_a",
+    "output_dir": "output/inference",
+    "config": None,
+}
+
+
+default_cmds = {
+    "prompt": None,
+    "negative_prompt": None,
+    "image": None,
+    "image_strength": .7,
+    "width": 512,
+    "height": 512,
+    "batch_size": 1,
+    "batch_num": 1,
+    "steps": 50,
+    "guidance_scale": 7.0,
+    "seed": None,
+    "config": None,
+}
+
+
+def merge_dicts(d1, *args):
+    d1 = d1.copy()
+
+    for d in args:
+        d1.update({k: v for (k, v) in d.items() if v is not None})
+
+    return d1
+
+
 def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
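The merge_dicts helper added above layers dictionaries left to right and skips None values, so a later dict can only override a key it actually sets. A minimal sketch of that behaviour (the override values below are made up for illustration):

    # Illustrative only: layer user-supplied values over default_cmds.
    overrides = {"steps": 30, "guidance_scale": None}
    merged = merge_dicts(default_cmds, overrides)
    # merged["steps"] == 30            (explicitly set, so it wins)
    # merged["guidance_scale"] == 7.0  (None counts as "not set" and is ignored)
    # merged["width"] == 512           (untouched default)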
@@ -22,23 +56,19 @@ def create_args_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
     )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/inference",
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
@@ -51,66 +81,69 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
-        default=None,
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+    )
+    parser.add_argument(
+        "--image_strength",
+        type=float,
     )
     parser.add_argument(
         "--width",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--height",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
-        type=int,
-        default=7,
+        type=float,
     )
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
 
 
-def run_parser(parser, input=None):
+def run_parser(parser, defaults, input=None):
     args = parser.parse_known_args(input)[0]
+    conf_args = argparse.Namespace()
 
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_known_args(
+            conf_args = parser.parse_known_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
-    return args
+    res = defaults.copy()
+    for dict in [vars(conf_args), vars(args)]:
+        res.update({k: v for (k, v) in dict.items() if v is not None})
+
+    return argparse.Namespace(**res)
 
 
 def save_args(basepath, args, extra={}):
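With the reworked run_parser, settings are resolved in three layers: the hard-coded defaults dict, then an optional JSON config file, then explicit command-line flags, where later non-None values win. A sketch of that intended precedence (the file name and numbers are invented; note that the second parse_known_args call inside run_parser reads sys.argv, so this assumes no conflicting flags on the real command line):

    # settings.json might contain: {"args": {"steps": 40, "guidance_scale": 9.0}}
    args = run_parser(create_cmd_parser(), default_cmds,
                      ["--config", "settings.json", "--steps", "25"])
    # args.steps == 25            -> the CLI flag beats the config file
    # args.guidance_scale == 9.0  -> the config file beats default_cmds (7.0)
    # args.width == 512           -> falls back to default_cmds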
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
     )
 
-    pipeline = CLIPGuidedStableDiffusion(
+    pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
-    pipeline.enable_attention_slicing()
+    # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
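The commented-out enable_attention_slicing() is diffusers' memory/speed trade-off: attention is computed in chunks so the full attention matrix never has to fit in VRAM at once. A self-contained conceptual sketch of the idea (not diffusers' actual implementation, and not part of this repository):

    import torch

    def sliced_attention(q, k, v, slice_size):
        # Conceptual sketch of attention slicing: process query rows in chunks so the
        # full (seq_len x seq_len) score matrix is never materialised at once.
        out = torch.empty_like(q)
        for start in range(0, q.shape[1], slice_size):
            end = start + slice_size
            scores = q[:, start:end] @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
            out[:, start:end] = scores.softmax(dim=-1) @ v
        return out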
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args):
 
     save_args(output_dir, args)
 
+    if args.image:
+        init_image = Image.open(args.image)
+        if not init_image.mode == "RGB":
+            init_image = init_image.convert("RGB")
+    else:
+        init_image = None
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
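This hunk only loads the init image and converts it to RGB; the commit is labelled WIP, and the PIL image is later handed to the pipeline as-is via the latents argument (next hunk). A finished diffusers-style img2img path conventionally resizes the image, maps it to [-1, 1], encodes it with the VAE and scales the latents; a rough sketch of that conventional preprocessing, not code from this repository:

    import numpy as np
    import torch
    from PIL import Image

    def preprocess_init_image(image: Image.Image, width: int, height: int) -> torch.Tensor:
        # Conventional diffusers-style img2img preprocessing (illustrative only).
        image = image.resize((width, height), resample=Image.LANCZOS)
        arr = np.array(image).astype(np.float32) / 255.0  # HWC in [0, 1]
        arr = arr[None].transpose(0, 3, 1, 2)             # NCHW
        return 2.0 * torch.from_numpy(arr) - 1.0          # [-1, 1]

    # Encoding to latents would then typically look like:
    #   latents = vae.encode(preprocess_init_image(init_image, 512, 512)).latent_dist.sample()
    #   latents = 0.18215 * latents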
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
+                latents=init_image,
             ).images
 
             for j, image in enumerate(images):
                 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 
 class CmdParse(cmd.Cmd):
     prompt = 'dream> '
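--image_strength is parsed but not yet used in the call above, again consistent with the WIP label. In img2img samplers, the strength value conventionally controls how much of the schedule actually runs on top of the noised init latents: strength 1.0 discards the image content entirely, strength near 0.0 barely changes it. A hedged sketch of that common mapping (the names are illustrative, not this pipeline's API):

    def img2img_start_step(strength: float, num_inference_steps: int) -> int:
        # Conventional mapping used by diffusers-style img2img pipelines (illustrative).
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        return max(num_inference_steps - init_timestep, 0)

    # e.g. strength 0.7 with 50 steps -> start at step 15 and run the remaining 35 steps.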
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         try:
-            args = run_parser(self.parser, elements)
+            args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()
 
@@ -233,7 +277,7 @@ def main():
     logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
     args_parser = create_args_parser()
-    args = run_parser(args_parser)
+    args = run_parser(args_parser, default_args)
     output_dir = Path(args.output_dir)
 
     pipeline = create_pipeline(args.model, args.scheduler)
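Taken together, the intended flow is: the global options pick the model, scheduler and output directory, while per-image options (including the new --image / --image_strength pair) are entered at the dream> prompt or pulled from a JSON config. An illustrative session, with placeholder paths:

    $ python infer.py --model <path-to-model> --scheduler euler_a --output_dir output/inference
    dream> --prompt "a watercolor painting of a fox" --image init.png --image_strength 0.7 --steps 50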