diff options
author | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:29 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:29 +0100 |
commit | b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d (patch) | |
tree | 2ad3740868696fc071d8850171e6e53ccc3a7bd2 /infer.py | |
parent | Update (diff) | |
download | textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.gz textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.bz2 textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.zip |
Update
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 52 |
1 file changed, 23 insertions, 29 deletions
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
20 | 20 | ||
21 | default_args = { | 21 | default_args = { |
22 | "model": None, | 22 | "model": None, |
23 | "scheduler": "dpmpp", | ||
24 | "precision": "fp32", | 23 | "precision": "fp32", |
25 | "ti_embeddings_dir": "embeddings_ti", | 24 | "ti_embeddings_dir": "embeddings_ti", |
26 | "output_dir": "output/inference", | 25 | "output_dir": "output/inference", |
@@ -29,6 +28,7 @@ default_args = { | |||
29 | 28 | ||
30 | 29 | ||
31 | default_cmds = { | 30 | default_cmds = { |
31 | "scheduler": "dpmpp", | ||
32 | "prompt": None, | 32 | "prompt": None, |
33 | "negative_prompt": None, | 33 | "negative_prompt": None, |
34 | "image": None, | 34 | "image": None, |
@@ -62,11 +62,6 @@ def create_args_parser(): | |||
62 | type=str, | 62 | type=str, |
63 | ) | 63 | ) |
64 | parser.add_argument( | 64 | parser.add_argument( |
65 | "--scheduler", | ||
66 | type=str, | ||
67 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], | ||
68 | ) | ||
69 | parser.add_argument( | ||
70 | "--precision", | 65 | "--precision", |
71 | type=str, | 66 | type=str, |
72 | choices=["fp32", "fp16", "bf16"], | 67 | choices=["fp32", "fp16", "bf16"], |
@@ -92,6 +87,11 @@ def create_cmd_parser(): | |||
92 | description="Simple example of a training script." | 87 | description="Simple example of a training script." |
93 | ) | 88 | ) |
94 | parser.add_argument( | 89 | parser.add_argument( |
90 | "--scheduler", | ||
91 | type=str, | ||
92 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], | ||
93 | ) | ||
94 | parser.add_argument( | ||
95 | "--prompt", | 95 | "--prompt", |
96 | type=str, | 96 | type=str, |
97 | nargs="+", | 97 | nargs="+", |
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): | |||
199 | print(f"Loaded {placeholder_token}") | 199 | print(f"Loaded {placeholder_token}") |
200 | 200 | ||
201 | 201 | ||
202 | def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | 202 | def create_pipeline(model, ti_embeddings_dir, dtype): |
203 | print("Loading Stable Diffusion pipeline...") | 203 | print("Loading Stable Diffusion pipeline...") |
204 | 204 | ||
205 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 205 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) |
206 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 206 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) |
207 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 207 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) |
208 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 208 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
209 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | ||
209 | 210 | ||
210 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) | 211 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) |
211 | 212 | ||
212 | if scheduler == "plms": | ||
213 | scheduler = PNDMScheduler( | ||
214 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | ||
215 | ) | ||
216 | elif scheduler == "klms": | ||
217 | scheduler = LMSDiscreteScheduler( | ||
218 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
219 | ) | ||
220 | elif scheduler == "ddim": | ||
221 | scheduler = DDIMScheduler( | ||
222 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
223 | ) | ||
224 | elif scheduler == "dpmpp": | ||
225 | scheduler = DPMSolverMultistepScheduler( | ||
226 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
227 | ) | ||
228 | else: | ||
229 | scheduler = EulerAncestralDiscreteScheduler( | ||
230 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
231 | ) | ||
232 | |||
233 | pipeline = VlpnStableDiffusion( | 213 | pipeline = VlpnStableDiffusion( |
234 | text_encoder=text_encoder, | 214 | text_encoder=text_encoder, |
235 | vae=vae, | 215 | vae=vae, |
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args): | |||
264 | else: | 244 | else: |
265 | init_image = None | 245 | init_image = None |
266 | 246 | ||
247 | if args.scheduler == "plms": | ||
248 | pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config) | ||
249 | elif args.scheduler == "klms": | ||
250 | pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
251 | elif args.scheduler == "ddim": | ||
252 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) | ||
253 | elif args.scheduler == "dpmpp": | ||
254 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) | ||
255 | elif args.scheduler == "euler_a": | ||
256 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
257 | |||
267 | with torch.autocast("cuda"), torch.inference_mode(): | 258 | with torch.autocast("cuda"), torch.inference_mode(): |
268 | for i in range(args.batch_num): | 259 | for i in range(args.batch_num): |
269 | pipeline.set_progress_bar_config( | 260 | pipeline.set_progress_bar_config( |
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd): | |||
331 | generate(self.output_dir, self.pipeline, args) | 322 | generate(self.output_dir, self.pipeline, args) |
332 | except KeyboardInterrupt: | 323 | except KeyboardInterrupt: |
333 | print('Generation cancelled.') | 324 | print('Generation cancelled.') |
325 | except Exception as e: | ||
326 | print(e) | ||
327 | return | ||
334 | 328 | ||
335 | def do_exit(self, line): | 329 | def do_exit(self, line): |
336 | return True | 330 | return True |
@@ -345,7 +339,7 @@ def main(): | |||
345 | output_dir = Path(args.output_dir) | 339 | output_dir = Path(args.output_dir) |
346 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 340 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
347 | 341 | ||
348 | pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype) | 342 | pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype) |
349 | cmd_parser = create_cmd_parser() | 343 | cmd_parser = create_cmd_parser() |
350 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | 344 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) |
351 | cmd_prompt.cmdloop() | 345 | cmd_prompt.cmdloop() |