diff options
author | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 |
commit | 49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch) | |
tree | 3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /infer.py | |
parent | Fix seed, better progress bar, fix euler_a for batch size > 1 (diff) | |
download | textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.gz textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.bz2 textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.zip |
WIP: img2img
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 90 |
1 files changed, 67 insertions, 23 deletions
@@ -8,13 +8,47 @@ from pathlib import Path | |||
8 | from torch import autocast | 8 | from torch import autocast |
9 | import torch | 9 | import torch |
10 | import json | 10 | import json |
11 | from PIL import Image | ||
11 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 12 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
12 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor |
13 | from slugify import slugify | 14 | from slugify import slugify |
14 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
15 | from schedulers.scheduling_euler_a import EulerAScheduler | 16 | from schedulers.scheduling_euler_a import EulerAScheduler |
16 | 17 | ||
17 | 18 | ||
19 | default_args = { | ||
20 | "model": None, | ||
21 | "scheduler": "euler_a", | ||
22 | "output_dir": "output/inference", | ||
23 | "config": None, | ||
24 | } | ||
25 | |||
26 | |||
27 | default_cmds = { | ||
28 | "prompt": None, | ||
29 | "negative_prompt": None, | ||
30 | "image": None, | ||
31 | "image_strength": .7, | ||
32 | "width": 512, | ||
33 | "height": 512, | ||
34 | "batch_size": 1, | ||
35 | "batch_num": 1, | ||
36 | "steps": 50, | ||
37 | "guidance_scale": 7.0, | ||
38 | "seed": None, | ||
39 | "config": None, | ||
40 | } | ||
41 | |||
42 | |||
43 | def merge_dicts(d1, *args): | ||
44 | d1 = d1.copy() | ||
45 | |||
46 | for d in args: | ||
47 | d1.update({k: v for (k, v) in d.items() if v is not None}) | ||
48 | |||
49 | return d1 | ||
50 | |||
51 | |||
18 | def create_args_parser(): | 52 | def create_args_parser(): |
19 | parser = argparse.ArgumentParser( | 53 | parser = argparse.ArgumentParser( |
20 | description="Simple example of a training script." | 54 | description="Simple example of a training script." |
@@ -22,23 +56,19 @@ def create_args_parser(): | |||
22 | parser.add_argument( | 56 | parser.add_argument( |
23 | "--model", | 57 | "--model", |
24 | type=str, | 58 | type=str, |
25 | default=None, | ||
26 | ) | 59 | ) |
27 | parser.add_argument( | 60 | parser.add_argument( |
28 | "--scheduler", | 61 | "--scheduler", |
29 | type=str, | 62 | type=str, |
30 | choices=["plms", "ddim", "klms", "euler_a"], | 63 | choices=["plms", "ddim", "klms", "euler_a"], |
31 | default="euler_a", | ||
32 | ) | 64 | ) |
33 | parser.add_argument( | 65 | parser.add_argument( |
34 | "--output_dir", | 66 | "--output_dir", |
35 | type=str, | 67 | type=str, |
36 | default="output/inference", | ||
37 | ) | 68 | ) |
38 | parser.add_argument( | 69 | parser.add_argument( |
39 | "--config", | 70 | "--config", |
40 | type=str, | 71 | type=str, |
41 | default=None, | ||
42 | ) | 72 | ) |
43 | 73 | ||
44 | return parser | 74 | return parser |
@@ -51,66 +81,69 @@ def create_cmd_parser(): | |||
51 | parser.add_argument( | 81 | parser.add_argument( |
52 | "--prompt", | 82 | "--prompt", |
53 | type=str, | 83 | type=str, |
54 | default=None, | ||
55 | ) | 84 | ) |
56 | parser.add_argument( | 85 | parser.add_argument( |
57 | "--negative_prompt", | 86 | "--negative_prompt", |
58 | type=str, | 87 | type=str, |
59 | default=None, | 88 | ) |
89 | parser.add_argument( | ||
90 | "--image", | ||
91 | type=str, | ||
92 | ) | ||
93 | parser.add_argument( | ||
94 | "--image_strength", | ||
95 | type=float, | ||
60 | ) | 96 | ) |
61 | parser.add_argument( | 97 | parser.add_argument( |
62 | "--width", | 98 | "--width", |
63 | type=int, | 99 | type=int, |
64 | default=512, | ||
65 | ) | 100 | ) |
66 | parser.add_argument( | 101 | parser.add_argument( |
67 | "--height", | 102 | "--height", |
68 | type=int, | 103 | type=int, |
69 | default=512, | ||
70 | ) | 104 | ) |
71 | parser.add_argument( | 105 | parser.add_argument( |
72 | "--batch_size", | 106 | "--batch_size", |
73 | type=int, | 107 | type=int, |
74 | default=1, | ||
75 | ) | 108 | ) |
76 | parser.add_argument( | 109 | parser.add_argument( |
77 | "--batch_num", | 110 | "--batch_num", |
78 | type=int, | 111 | type=int, |
79 | default=1, | ||
80 | ) | 112 | ) |
81 | parser.add_argument( | 113 | parser.add_argument( |
82 | "--steps", | 114 | "--steps", |
83 | type=int, | 115 | type=int, |
84 | default=70, | ||
85 | ) | 116 | ) |
86 | parser.add_argument( | 117 | parser.add_argument( |
87 | "--guidance_scale", | 118 | "--guidance_scale", |
88 | type=int, | 119 | type=float, |
89 | default=7, | ||
90 | ) | 120 | ) |
91 | parser.add_argument( | 121 | parser.add_argument( |
92 | "--seed", | 122 | "--seed", |
93 | type=int, | 123 | type=int, |
94 | default=None, | ||
95 | ) | 124 | ) |
96 | parser.add_argument( | 125 | parser.add_argument( |
97 | "--config", | 126 | "--config", |
98 | type=str, | 127 | type=str, |
99 | default=None, | ||
100 | ) | 128 | ) |
101 | 129 | ||
102 | return parser | 130 | return parser |
103 | 131 | ||
104 | 132 | ||
105 | def run_parser(parser, input=None): | 133 | def run_parser(parser, defaults, input=None): |
106 | args = parser.parse_known_args(input)[0] | 134 | args = parser.parse_known_args(input)[0] |
135 | conf_args = argparse.Namespace() | ||
107 | 136 | ||
108 | if args.config is not None: | 137 | if args.config is not None: |
109 | with open(args.config, 'rt') as f: | 138 | with open(args.config, 'rt') as f: |
110 | args = parser.parse_known_args( | 139 | conf_args = parser.parse_known_args( |
111 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] | 140 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] |
112 | 141 | ||
113 | return args | 142 | res = defaults.copy() |
143 | for dict in [vars(conf_args), vars(args)]: | ||
144 | res.update({k: v for (k, v) in dict.items() if v is not None}) | ||
145 | |||
146 | return argparse.Namespace(**res) | ||
114 | 147 | ||
115 | 148 | ||
116 | def save_args(basepath, args, extra={}): | 149 | def save_args(basepath, args, extra={}): |
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16): | |||
146 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 179 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
147 | ) | 180 | ) |
148 | 181 | ||
149 | pipeline = CLIPGuidedStableDiffusion( | 182 | pipeline = VlpnStableDiffusion( |
150 | text_encoder=text_encoder, | 183 | text_encoder=text_encoder, |
151 | vae=vae, | 184 | vae=vae, |
152 | unet=unet, | 185 | unet=unet, |
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16): | |||
154 | scheduler=scheduler, | 187 | scheduler=scheduler, |
155 | feature_extractor=feature_extractor | 188 | feature_extractor=feature_extractor |
156 | ) | 189 | ) |
157 | pipeline.enable_attention_slicing() | 190 | # pipeline.enable_attention_slicing() |
158 | pipeline.to("cuda") | 191 | pipeline.to("cuda") |
159 | 192 | ||
160 | print("Pipeline loaded.") | 193 | print("Pipeline loaded.") |
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args): | |||
171 | 204 | ||
172 | save_args(output_dir, args) | 205 | save_args(output_dir, args) |
173 | 206 | ||
207 | if args.image: | ||
208 | init_image = Image.open(args.image) | ||
209 | if not init_image.mode == "RGB": | ||
210 | init_image = init_image.convert("RGB") | ||
211 | else: | ||
212 | init_image = None | ||
213 | |||
174 | with autocast("cuda"): | 214 | with autocast("cuda"): |
175 | for i in range(args.batch_num): | 215 | for i in range(args.batch_num): |
176 | pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") | 216 | pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") |
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args): | |||
184 | num_inference_steps=args.steps, | 224 | num_inference_steps=args.steps, |
185 | guidance_scale=args.guidance_scale, | 225 | guidance_scale=args.guidance_scale, |
186 | generator=generator, | 226 | generator=generator, |
227 | latents=init_image, | ||
187 | ).images | 228 | ).images |
188 | 229 | ||
189 | for j, image in enumerate(images): | 230 | for j, image in enumerate(images): |
190 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) | 231 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
191 | 232 | ||
233 | if torch.cuda.is_available(): | ||
234 | torch.cuda.empty_cache() | ||
235 | |||
192 | 236 | ||
193 | class CmdParse(cmd.Cmd): | 237 | class CmdParse(cmd.Cmd): |
194 | prompt = 'dream> ' | 238 | prompt = 'dream> ' |
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd): | |||
213 | return True | 257 | return True |
214 | 258 | ||
215 | try: | 259 | try: |
216 | args = run_parser(self.parser, elements) | 260 | args = run_parser(self.parser, default_cmds, elements) |
217 | except SystemExit: | 261 | except SystemExit: |
218 | self.parser.print_help() | 262 | self.parser.print_help() |
219 | 263 | ||
@@ -233,7 +277,7 @@ def main(): | |||
233 | logging.basicConfig(stream=sys.stdout, level=logging.WARN) | 277 | logging.basicConfig(stream=sys.stdout, level=logging.WARN) |
234 | 278 | ||
235 | args_parser = create_args_parser() | 279 | args_parser = create_args_parser() |
236 | args = run_parser(args_parser) | 280 | args = run_parser(args_parser, default_args) |
237 | output_dir = Path(args.output_dir) | 281 | output_dir = Path(args.output_dir) |
238 | 282 | ||
239 | pipeline = create_pipeline(args.model, args.scheduler) | 283 | pipeline = create_pipeline(args.model, args.scheduler) |