summary refs log tree commit diff stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2022-09-28 18:32:15 +0200
committer Volpeon <git@volpeon.ink>  2022-09-28 18:32:15 +0200
commit    2a65b4eb29e4874c153a9517ab06b93481c2d238 (patch)
tree      aadd2a783f4b84dab4b0928a510f8625211b3e20 /dreambooth.py
parent    Improved sample output and progress bars (diff)
download  textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.tar.gz
          textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.tar.bz2
          textual-inversion-diff-2a65b4eb29e4874c153a9517ab06b93481c2d238.zip
Batches of size 1 cause error: Expected query.is_contiguous() to be true, but got false
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 27
1 file changed, 7 insertions(+), 20 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 2df6858..0c58ab5 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -433,7 +433,7 @@ class Checkpointer:
433 del image_grid 433 del image_grid
434 del stable_latents 434 del stable_latents
435 435
436 for data, pool in [(train_data, "train"), (val_data, "val")]: 436 for data, pool in [(val_data, "val"), (train_data, "train")]:
437 all_samples = [] 437 all_samples = []
438 filename = f"step_{step}_{pool}.png" 438 filename = f"step_{step}_{pool}.png"
439 439
@@ -492,12 +492,11 @@ def main():
492 492
493 if args.with_prior_preservation: 493 if args.with_prior_preservation:
494 class_images_dir = Path(args.class_data_dir) 494 class_images_dir = Path(args.class_data_dir)
495 if not class_images_dir.exists(): 495 class_images_dir.mkdir(parents=True, exist_ok=True)
496 class_images_dir.mkdir(parents=True)
497 cur_class_images = len(list(class_images_dir.iterdir())) 496 cur_class_images = len(list(class_images_dir.iterdir()))
498 497
499 if cur_class_images < args.num_class_images: 498 if cur_class_images < args.num_class_images:
500 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 499 torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32
501 pipeline = StableDiffusionPipeline.from_pretrained( 500 pipeline = StableDiffusionPipeline.from_pretrained(
502 args.pretrained_model_name_or_path, torch_dtype=torch_dtype) 501 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
503 pipeline.enable_attention_slicing() 502 pipeline.enable_attention_slicing()
@@ -581,7 +580,6 @@ def main():
581 eps=args.adam_epsilon, 580 eps=args.adam_epsilon,
582 ) 581 )
583 582
584 # TODO (patil-suraj): laod scheduler using args
585 noise_scheduler = DDPMScheduler( 583 noise_scheduler = DDPMScheduler(
586 beta_start=0.00085, 584 beta_start=0.00085,
587 beta_end=0.012, 585 beta_end=0.012,
@@ -595,7 +593,7 @@ def main():
595 pixel_values = [example["instance_images"] for example in examples] 593 pixel_values = [example["instance_images"] for example in examples]
596 594
597 # concat class and instance examples for prior preservation 595 # concat class and instance examples for prior preservation
598 if args.with_prior_preservation: 596 if args.with_prior_preservation and "class_prompt_ids" in examples[0]:
599 input_ids += [example["class_prompt_ids"] for example in examples] 597 input_ids += [example["class_prompt_ids"] for example in examples]
600 pixel_values += [example["class_images"] for example in examples] 598 pixel_values += [example["class_images"] for example in examples]
601 599
@@ -789,6 +787,8 @@ def main():
789 787
790 train_loss /= len(train_dataloader) 788 train_loss /= len(train_dataloader)
791 789
790 accelerator.wait_for_everyone()
791
792 unet.eval() 792 unet.eval()
793 val_loss = 0.0 793 val_loss = 0.0
794 794
@@ -812,18 +812,7 @@ def main():
812 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 812 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
813 813
814 with accelerator.autocast(): 814 with accelerator.autocast():
815 if args.with_prior_preservation: 815 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
816 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
817 noise, noise_prior = torch.chunk(noise, 2, dim=0)
818
819 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
820
821 prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
822 reduction="none").mean([1, 2, 3]).mean()
823
824 loss = loss + args.prior_loss_weight * prior_loss
825 else:
826 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
827 816
828 loss = loss.detach().item() 817 loss = loss.detach().item()
829 val_loss += loss 818 val_loss += loss
@@ -851,8 +840,6 @@ def main():
851 global_step, 840 global_step,
852 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 841 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
853 842
854 accelerator.wait_for_everyone()
855
856 # Create the pipeline using using the trained modules and save it. 843 # Create the pipeline using using the trained modules and save it.
857 if accelerator.is_main_process: 844 if accelerator.is_main_process:
858 print("Finished! Saving final checkpoint and resume state.") 845 print("Finished! Saving final checkpoint and resume state.")