diff options
author | Volpeon <git@volpeon.ink> | 2022-09-28 18:32:15 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-28 18:32:15 +0200 |
commit | 2a65b4eb29e4874c153a9517ab06b93481c2d238 (patch) | |
tree | aadd2a783f4b84dab4b0928a510f8625211b3e20 /dreambooth.py | |
parent | Improved sample output and progress bars (diff) | |
download | textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.tar.gz textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.tar.bz2 textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.zip |
Batches of size 1 cause error: Expected query.is_contiguous() to be true, but got false
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 27 |
1 file changed, 7 insertions, 20 deletions
diff --git a/dreambooth.py b/dreambooth.py index 2df6858..0c58ab5 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -433,7 +433,7 @@ class Checkpointer: | |||
433 | del image_grid | 433 | del image_grid |
434 | del stable_latents | 434 | del stable_latents |
435 | 435 | ||
436 | for data, pool in [(train_data, "train"), (val_data, "val")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
437 | all_samples = [] | 437 | all_samples = [] |
438 | filename = f"step_{step}_{pool}.png" | 438 | filename = f"step_{step}_{pool}.png" |
439 | 439 | ||
@@ -492,12 +492,11 @@ def main(): | |||
492 | 492 | ||
493 | if args.with_prior_preservation: | 493 | if args.with_prior_preservation: |
494 | class_images_dir = Path(args.class_data_dir) | 494 | class_images_dir = Path(args.class_data_dir) |
495 | if not class_images_dir.exists(): | 495 | class_images_dir.mkdir(parents=True, exist_ok=True) |
496 | class_images_dir.mkdir(parents=True) | ||
497 | cur_class_images = len(list(class_images_dir.iterdir())) | 496 | cur_class_images = len(list(class_images_dir.iterdir())) |
498 | 497 | ||
499 | if cur_class_images < args.num_class_images: | 498 | if cur_class_images < args.num_class_images: |
500 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 | 499 | torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 |
501 | pipeline = StableDiffusionPipeline.from_pretrained( | 500 | pipeline = StableDiffusionPipeline.from_pretrained( |
502 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 501 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
503 | pipeline.enable_attention_slicing() | 502 | pipeline.enable_attention_slicing() |
@@ -581,7 +580,6 @@ def main(): | |||
581 | eps=args.adam_epsilon, | 580 | eps=args.adam_epsilon, |
582 | ) | 581 | ) |
583 | 582 | ||
584 | # TODO (patil-suraj): laod scheduler using args | ||
585 | noise_scheduler = DDPMScheduler( | 583 | noise_scheduler = DDPMScheduler( |
586 | beta_start=0.00085, | 584 | beta_start=0.00085, |
587 | beta_end=0.012, | 585 | beta_end=0.012, |
@@ -595,7 +593,7 @@ def main(): | |||
595 | pixel_values = [example["instance_images"] for example in examples] | 593 | pixel_values = [example["instance_images"] for example in examples] |
596 | 594 | ||
597 | # concat class and instance examples for prior preservation | 595 | # concat class and instance examples for prior preservation |
598 | if args.with_prior_preservation: | 596 | if args.with_prior_preservation and "class_prompt_ids" in examples[0]: |
599 | input_ids += [example["class_prompt_ids"] for example in examples] | 597 | input_ids += [example["class_prompt_ids"] for example in examples] |
600 | pixel_values += [example["class_images"] for example in examples] | 598 | pixel_values += [example["class_images"] for example in examples] |
601 | 599 | ||
@@ -789,6 +787,8 @@ def main(): | |||
789 | 787 | ||
790 | train_loss /= len(train_dataloader) | 788 | train_loss /= len(train_dataloader) |
791 | 789 | ||
790 | accelerator.wait_for_everyone() | ||
791 | |||
792 | unet.eval() | 792 | unet.eval() |
793 | val_loss = 0.0 | 793 | val_loss = 0.0 |
794 | 794 | ||
@@ -812,18 +812,7 @@ def main(): | |||
812 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 812 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
813 | 813 | ||
814 | with accelerator.autocast(): | 814 | with accelerator.autocast(): |
815 | if args.with_prior_preservation: | 815 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
816 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | ||
817 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
818 | |||
819 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
820 | |||
821 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | ||
822 | reduction="none").mean([1, 2, 3]).mean() | ||
823 | |||
824 | loss = loss + args.prior_loss_weight * prior_loss | ||
825 | else: | ||
826 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
827 | 816 | ||
828 | loss = loss.detach().item() | 817 | loss = loss.detach().item() |
829 | val_loss += loss | 818 | val_loss += loss |
@@ -851,8 +840,6 @@ def main(): | |||
851 | global_step, | 840 | global_step, |
852 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 841 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
853 | 842 | ||
854 | accelerator.wait_for_everyone() | ||
855 | |||
856 | # Create the pipeline using using the trained modules and save it. | 843 | # Create the pipeline using using the trained modules and save it. |
857 | if accelerator.is_main_process: | 844 | if accelerator.is_main_process: |
858 | print("Finished! Saving final checkpoint and resume state.") | 845 | print("Finished! Saving final checkpoint and resume state.") |