diff options
| author | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:29 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:29 +0100 |
| commit | b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d (patch) | |
| tree | 2ad3740868696fc071d8850171e6e53ccc3a7bd2 /infer.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.gz textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.bz2 textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.zip | |
Update
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 52 |
1 file changed, 23 insertions, 29 deletions
| @@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
| 20 | 20 | ||
| 21 | default_args = { | 21 | default_args = { |
| 22 | "model": None, | 22 | "model": None, |
| 23 | "scheduler": "dpmpp", | ||
| 24 | "precision": "fp32", | 23 | "precision": "fp32", |
| 25 | "ti_embeddings_dir": "embeddings_ti", | 24 | "ti_embeddings_dir": "embeddings_ti", |
| 26 | "output_dir": "output/inference", | 25 | "output_dir": "output/inference", |
| @@ -29,6 +28,7 @@ default_args = { | |||
| 29 | 28 | ||
| 30 | 29 | ||
| 31 | default_cmds = { | 30 | default_cmds = { |
| 31 | "scheduler": "dpmpp", | ||
| 32 | "prompt": None, | 32 | "prompt": None, |
| 33 | "negative_prompt": None, | 33 | "negative_prompt": None, |
| 34 | "image": None, | 34 | "image": None, |
| @@ -62,11 +62,6 @@ def create_args_parser(): | |||
| 62 | type=str, | 62 | type=str, |
| 63 | ) | 63 | ) |
| 64 | parser.add_argument( | 64 | parser.add_argument( |
| 65 | "--scheduler", | ||
| 66 | type=str, | ||
| 67 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], | ||
| 68 | ) | ||
| 69 | parser.add_argument( | ||
| 70 | "--precision", | 65 | "--precision", |
| 71 | type=str, | 66 | type=str, |
| 72 | choices=["fp32", "fp16", "bf16"], | 67 | choices=["fp32", "fp16", "bf16"], |
| @@ -92,6 +87,11 @@ def create_cmd_parser(): | |||
| 92 | description="Simple example of a training script." | 87 | description="Simple example of a training script." |
| 93 | ) | 88 | ) |
| 94 | parser.add_argument( | 89 | parser.add_argument( |
| 90 | "--scheduler", | ||
| 91 | type=str, | ||
| 92 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], | ||
| 93 | ) | ||
| 94 | parser.add_argument( | ||
| 95 | "--prompt", | 95 | "--prompt", |
| 96 | type=str, | 96 | type=str, |
| 97 | nargs="+", | 97 | nargs="+", |
| @@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): | |||
| 199 | print(f"Loaded {placeholder_token}") | 199 | print(f"Loaded {placeholder_token}") |
| 200 | 200 | ||
| 201 | 201 | ||
| 202 | def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | 202 | def create_pipeline(model, ti_embeddings_dir, dtype): |
| 203 | print("Loading Stable Diffusion pipeline...") | 203 | print("Loading Stable Diffusion pipeline...") |
| 204 | 204 | ||
| 205 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 205 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) |
| 206 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 206 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) |
| 207 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 207 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) |
| 208 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 208 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
| 209 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | ||
| 209 | 210 | ||
| 210 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) | 211 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) |
| 211 | 212 | ||
| 212 | if scheduler == "plms": | ||
| 213 | scheduler = PNDMScheduler( | ||
| 214 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | ||
| 215 | ) | ||
| 216 | elif scheduler == "klms": | ||
| 217 | scheduler = LMSDiscreteScheduler( | ||
| 218 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 219 | ) | ||
| 220 | elif scheduler == "ddim": | ||
| 221 | scheduler = DDIMScheduler( | ||
| 222 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
| 223 | ) | ||
| 224 | elif scheduler == "dpmpp": | ||
| 225 | scheduler = DPMSolverMultistepScheduler( | ||
| 226 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 227 | ) | ||
| 228 | else: | ||
| 229 | scheduler = EulerAncestralDiscreteScheduler( | ||
| 230 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 231 | ) | ||
| 232 | |||
| 233 | pipeline = VlpnStableDiffusion( | 213 | pipeline = VlpnStableDiffusion( |
| 234 | text_encoder=text_encoder, | 214 | text_encoder=text_encoder, |
| 235 | vae=vae, | 215 | vae=vae, |
| @@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args): | |||
| 264 | else: | 244 | else: |
| 265 | init_image = None | 245 | init_image = None |
| 266 | 246 | ||
| 247 | if args.scheduler == "plms": | ||
| 248 | pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config) | ||
| 249 | elif args.scheduler == "klms": | ||
| 250 | pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
| 251 | elif args.scheduler == "ddim": | ||
| 252 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) | ||
| 253 | elif args.scheduler == "dpmpp": | ||
| 254 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) | ||
| 255 | elif args.scheduler == "euler_a": | ||
| 256 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
| 257 | |||
| 267 | with torch.autocast("cuda"), torch.inference_mode(): | 258 | with torch.autocast("cuda"), torch.inference_mode(): |
| 268 | for i in range(args.batch_num): | 259 | for i in range(args.batch_num): |
| 269 | pipeline.set_progress_bar_config( | 260 | pipeline.set_progress_bar_config( |
| @@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd): | |||
| 331 | generate(self.output_dir, self.pipeline, args) | 322 | generate(self.output_dir, self.pipeline, args) |
| 332 | except KeyboardInterrupt: | 323 | except KeyboardInterrupt: |
| 333 | print('Generation cancelled.') | 324 | print('Generation cancelled.') |
| 325 | except Exception as e: | ||
| 326 | print(e) | ||
| 327 | return | ||
| 334 | 328 | ||
| 335 | def do_exit(self, line): | 329 | def do_exit(self, line): |
| 336 | return True | 330 | return True |
| @@ -345,7 +339,7 @@ def main(): | |||
| 345 | output_dir = Path(args.output_dir) | 339 | output_dir = Path(args.output_dir) |
| 346 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 340 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
| 347 | 341 | ||
| 348 | pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype) | 342 | pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype) |
| 349 | cmd_parser = create_cmd_parser() | 343 | cmd_parser = create_cmd_parser() |
| 350 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | 344 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) |
| 351 | cmd_prompt.cmdloop() | 345 | cmd_prompt.cmdloop() |
