| author | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
| commit | c90099f06e0b461660b326fb6d86b69d86e78289 (patch) | |
| tree | df4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /textual_inversion.py | |
| parent | Fixed euler_a generator argument (diff) | |
Added negative prompt support for training scripts
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 83 |
1 file changed, 23 insertions(+), 60 deletions(-)
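For context on the commit message: a negative prompt gives classifier-free guidance something to steer *away* from, complementing the positive prompt. A minimal sketch of the idea at inference time, assuming a recent diffusers release whose `StableDiffusionPipeline` accepts `negative_prompt` and returns `.images` (this repo's own pipeline wrapper may differ; the model id is just an illustration):

```python
# Hedged sketch, not this repo's code: a negative prompt passed to the
# stock diffusers pipeline alongside the positive prompt.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
image = pipe(
    prompt="a photo of a corgi on a beach",
    negative_prompt="blurry, low quality, deformed",  # what guidance pushes away from
    guidance_scale=7.5,
).images[0]
image.save("sample.png")
```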
```diff
diff --git a/textual_inversion.py b/textual_inversion.py
index 285aa0a..00d460f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -207,16 +207,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
```
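The two old flags, `--stable_sample_batches` and `--random_sample_batches`, collapse into a single `--sample_batches`. A self-contained round-trip of the new flag (plain argparse, mirroring the `add_argument` call in the hunk above):

```python
# Round-trip of the consolidated --sample_batches flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sample_batches",
    type=int,
    default=1,
    help="Number of sample batches to generate per checkpoint",
)

args = parser.parse_args(["--sample_batches", "2"])
assert args.sample_batches == 2

args = parser.parse_args([])
assert args.sample_batches == 1  # default applies when the flag is omitted
```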
```diff
@@ -319,9 +313,8 @@ class Checkpointer:
         placeholder_token_id,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -334,9 +327,8 @@
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches

     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
```
```diff
@@ -385,63 +377,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()

-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )

-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)

             data_enum = enumerate(data)

-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
```
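The rewrite folds the old dedicated "stable" branch into one pool loop: the stable pool reuses `val_data` but pins `latents` and `generator`, while the val and train pools pass `None` and sample fresh noise. The prompt gathering also changed shape; a toy reproduction with plain lists standing in for a DataLoader (the dicts and `batch_size` are stand-ins) shows what `batches`, `prompt`, and `nprompt` contain. Note that a list comprehension filters but still consumes the whole iterator, so the first pass exhausts the shared `data_enum`:

```python
# Toy reproduction of the new prompt-gathering comprehensions; `data`
# stands in for a DataLoader whose batches carry prompts and nprompts.
sample_batch_size = 4
batch_size = 2  # data.batch_size in the diff
data = [
    {"prompts": ["a", "b"], "nprompts": ["na", "nb"]},
    {"prompts": ["c", "d"], "nprompts": ["nc", "nd"]},
    {"prompts": ["e", "f"], "nprompts": ["ne", "nf"]},
]

data_enum = enumerate(data)
# Keep batches while j * batch_size < sample_batch_size: here j = 0 and 1 pass.
batches = [batch for j, batch in data_enum if j * batch_size < sample_batch_size]
prompt = [p for batch in batches for p in batch["prompts"]][:sample_batch_size]
nprompt = [p for batch in batches for p in batch["nprompts"]][:sample_batch_size]

print(prompt)   # ['a', 'b', 'c', 'd']
print(nprompt)  # ['na', 'nb', 'nc', 'nd']
```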
```diff
@@ -452,7 +414,7 @@ class Checkpointer:

                 del samples

-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)

             del all_samples
```
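`make_grid` here is the repo's own helper, not shown in this diff (and not torchvision's tensor version, since it receives PIL images plus row and column counts). A plausible stand-in, assuming that signature, which would lay each batch out as one row of the grid:

```python
# Hypothetical stand-in for the repo's make_grid(images, rows, cols);
# pastes PIL images left-to-right, top-to-bottom into one sheet.
from PIL import Image

def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, (i % cols * w, i // cols * h))
    return grid
```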
```diff
@@ -461,6 +423,8 @@
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents

         if torch.cuda.is_available():
             torch.cuda.empty_cache()
```
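The new `del generator` and `del stable_latents` mirror the setup at the top of the function. The point of seeding a single `torch.Generator` there is that the stable pool starts from identical noise at every checkpoint, so its grids stay directly comparable across training. A self-contained check of that property:

```python
# Same seed -> same starting latents; this is what keeps the "stable"
# sample grid comparable from checkpoint to checkpoint.
import torch

seed = 42
device = torch.device("cpu")  # pipeline.device in the diff

latents_a = torch.randn((4, 4, 64, 64), device=device,
                        generator=torch.Generator(device=device).manual_seed(seed))
latents_b = torch.randn((4, 4, 64, 64), device=device,
                        generator=torch.Generator(device=device).manual_seed(seed))

assert torch.equal(latents_a, latents_b)
```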
```diff
@@ -603,7 +567,7 @@ def main():
         placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches
     )

     datamodule.prepare_data()
```
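The validation split is sized to what sampling consumes: the stable pool draws all of its prompts from `val_data`, and one checkpoint renders `sample_batches` batches of `sample_batch_size` images per pool, so:

```python
# Sizing check: the validation set must supply one full grid of prompts.
sample_batch_size = 4
sample_batches = 2
valid_set_size = sample_batch_size * sample_batches
print(valid_set_size)  # 8 prompts reserved for validation sampling
```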
```diff
@@ -623,8 +587,7 @@
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )

```
