summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/dreambooth/csv.py18
-rw-r--r--dreambooth.py27
2 files changed, 16 insertions, 29 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 85ed4a5..99bcf12 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -1,3 +1,4 @@
1import math
1import os 2import os
2import pandas as pd 3import pandas as pd
3from pathlib import Path 4from pathlib import Path
@@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule):
57 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 58 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
58 class_data_root=self.class_data_root, class_prompt=self.class_prompt, 59 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
59 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 60 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
60 center_crop=self.center_crop, repeats=self.repeats) 61 center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size)
61 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, 62 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
62 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
63 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 63 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
64 center_crop=self.center_crop) 64 center_crop=self.center_crop, batch_size=self.batch_size)
65 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 65 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
66 shuffle=True, collate_fn=self.collate_fn) 66 shuffle=True, collate_fn=self.collate_fn)
67 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) 67 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
@@ -85,22 +85,24 @@ class CSVDataset(Dataset):
85 interpolation="bicubic", 85 interpolation="bicubic",
86 identifier="*", 86 identifier="*",
87 center_crop=False, 87 center_crop=False,
88 batch_size=1,
88 ): 89 ):
89 90
90 self.data = data 91 self.data = data
91 self.tokenizer = tokenizer 92 self.tokenizer = tokenizer
92 self.instance_prompt = instance_prompt 93 self.instance_prompt = instance_prompt
94 self.identifier = identifier
95 self.batch_size = batch_size
96 self.cache = {}
93 97
94 self.num_instance_images = len(self.data) 98 self.num_instance_images = len(self.data)
95 self._length = self.num_instance_images * repeats 99 self._length = self.num_instance_images * repeats
96 100
97 self.identifier = identifier
98
99 if class_data_root is not None: 101 if class_data_root is not None:
100 self.class_data_root = Path(class_data_root) 102 self.class_data_root = Path(class_data_root)
101 self.class_data_root.mkdir(parents=True, exist_ok=True) 103 self.class_data_root.mkdir(parents=True, exist_ok=True)
102 104
103 self.class_images = list(Path(class_data_root).iterdir()) 105 self.class_images = list(self.class_data_root.iterdir())
104 self.num_class_images = len(self.class_images) 106 self.num_class_images = len(self.class_images)
105 self._length = max(self.num_class_images, self.num_instance_images) 107 self._length = max(self.num_class_images, self.num_instance_images)
106 108
@@ -123,10 +125,8 @@ class CSVDataset(Dataset):
123 ] 125 ]
124 ) 126 )
125 127
126 self.cache = {}
127
128 def __len__(self): 128 def __len__(self):
129 return self._length 129 return math.ceil(self._length / self.batch_size) * self.batch_size
130 130
131 def get_example(self, i): 131 def get_example(self, i):
132 image_path, text = self.data[i % self.num_instance_images] 132 image_path, text = self.data[i % self.num_instance_images]
diff --git a/dreambooth.py b/dreambooth.py
index 2df6858..0c58ab5 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -433,7 +433,7 @@ class Checkpointer:
433 del image_grid 433 del image_grid
434 del stable_latents 434 del stable_latents
435 435
436 for data, pool in [(train_data, "train"), (val_data, "val")]: 436 for data, pool in [(val_data, "val"), (train_data, "train")]:
437 all_samples = [] 437 all_samples = []
438 filename = f"step_{step}_{pool}.png" 438 filename = f"step_{step}_{pool}.png"
439 439
@@ -492,12 +492,11 @@ def main():
492 492
493 if args.with_prior_preservation: 493 if args.with_prior_preservation:
494 class_images_dir = Path(args.class_data_dir) 494 class_images_dir = Path(args.class_data_dir)
495 if not class_images_dir.exists(): 495 class_images_dir.mkdir(parents=True, exist_ok=True)
496 class_images_dir.mkdir(parents=True)
497 cur_class_images = len(list(class_images_dir.iterdir())) 496 cur_class_images = len(list(class_images_dir.iterdir()))
498 497
499 if cur_class_images < args.num_class_images: 498 if cur_class_images < args.num_class_images:
500 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 499 torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32
501 pipeline = StableDiffusionPipeline.from_pretrained( 500 pipeline = StableDiffusionPipeline.from_pretrained(
502 args.pretrained_model_name_or_path, torch_dtype=torch_dtype) 501 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
503 pipeline.enable_attention_slicing() 502 pipeline.enable_attention_slicing()
@@ -581,7 +580,6 @@ def main():
581 eps=args.adam_epsilon, 580 eps=args.adam_epsilon,
582 ) 581 )
583 582
584 # TODO (patil-suraj): laod scheduler using args
585 noise_scheduler = DDPMScheduler( 583 noise_scheduler = DDPMScheduler(
586 beta_start=0.00085, 584 beta_start=0.00085,
587 beta_end=0.012, 585 beta_end=0.012,
@@ -595,7 +593,7 @@ def main():
595 pixel_values = [example["instance_images"] for example in examples] 593 pixel_values = [example["instance_images"] for example in examples]
596 594
597 # concat class and instance examples for prior preservation 595 # concat class and instance examples for prior preservation
598 if args.with_prior_preservation: 596 if args.with_prior_preservation and "class_prompt_ids" in examples[0]:
599 input_ids += [example["class_prompt_ids"] for example in examples] 597 input_ids += [example["class_prompt_ids"] for example in examples]
600 pixel_values += [example["class_images"] for example in examples] 598 pixel_values += [example["class_images"] for example in examples]
601 599
@@ -789,6 +787,8 @@ def main():
789 787
790 train_loss /= len(train_dataloader) 788 train_loss /= len(train_dataloader)
791 789
790 accelerator.wait_for_everyone()
791
792 unet.eval() 792 unet.eval()
793 val_loss = 0.0 793 val_loss = 0.0
794 794
@@ -812,18 +812,7 @@ def main():
812 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 812 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
813 813
814 with accelerator.autocast(): 814 with accelerator.autocast():
815 if args.with_prior_preservation: 815 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
816 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
817 noise, noise_prior = torch.chunk(noise, 2, dim=0)
818
819 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
820
821 prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
822 reduction="none").mean([1, 2, 3]).mean()
823
824 loss = loss + args.prior_loss_weight * prior_loss
825 else:
826 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
827 816
828 loss = loss.detach().item() 817 loss = loss.detach().item()
829 val_loss += loss 818 val_loss += loss
@@ -851,8 +840,6 @@ def main():
851 global_step, 840 global_step,
852 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 841 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
853 842
854 accelerator.wait_for_everyone()
855
856 # Create the pipeline using using the trained modules and save it. 843 # Create the pipeline using using the trained modules and save it.
857 if accelerator.is_main_process: 844 if accelerator.is_main_process:
858 print("Finished! Saving final checkpoint and resume state.") 845 print("Finished! Saving final checkpoint and resume state.")