diff options
| -rw-r--r-- | data/dreambooth/csv.py | 18 | ||||
| -rw-r--r-- | dreambooth.py | 27 |
2 files changed, 16 insertions, 29 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 85ed4a5..99bcf12 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -1,3 +1,4 @@ | |||
| 1 | import math | ||
| 1 | import os | 2 | import os |
| 2 | import pandas as pd | 3 | import pandas as pd |
| 3 | from pathlib import Path | 4 | from pathlib import Path |
| @@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 57 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 58 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
| 58 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 59 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
| 59 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 60 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
| 60 | center_crop=self.center_crop, repeats=self.repeats) | 61 | center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) |
| 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, | 62 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, |
| 62 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | ||
| 63 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 63 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
| 64 | center_crop=self.center_crop) | 64 | center_crop=self.center_crop, batch_size=self.batch_size) |
| 65 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 65 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 66 | shuffle=True, collate_fn=self.collate_fn) | 66 | shuffle=True, collate_fn=self.collate_fn) |
| 67 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) | 67 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) |
| @@ -85,22 +85,24 @@ class CSVDataset(Dataset): | |||
| 85 | interpolation="bicubic", | 85 | interpolation="bicubic", |
| 86 | identifier="*", | 86 | identifier="*", |
| 87 | center_crop=False, | 87 | center_crop=False, |
| 88 | batch_size=1, | ||
| 88 | ): | 89 | ): |
| 89 | 90 | ||
| 90 | self.data = data | 91 | self.data = data |
| 91 | self.tokenizer = tokenizer | 92 | self.tokenizer = tokenizer |
| 92 | self.instance_prompt = instance_prompt | 93 | self.instance_prompt = instance_prompt |
| 94 | self.identifier = identifier | ||
| 95 | self.batch_size = batch_size | ||
| 96 | self.cache = {} | ||
| 93 | 97 | ||
| 94 | self.num_instance_images = len(self.data) | 98 | self.num_instance_images = len(self.data) |
| 95 | self._length = self.num_instance_images * repeats | 99 | self._length = self.num_instance_images * repeats |
| 96 | 100 | ||
| 97 | self.identifier = identifier | ||
| 98 | |||
| 99 | if class_data_root is not None: | 101 | if class_data_root is not None: |
| 100 | self.class_data_root = Path(class_data_root) | 102 | self.class_data_root = Path(class_data_root) |
| 101 | self.class_data_root.mkdir(parents=True, exist_ok=True) | 103 | self.class_data_root.mkdir(parents=True, exist_ok=True) |
| 102 | 104 | ||
| 103 | self.class_images = list(Path(class_data_root).iterdir()) | 105 | self.class_images = list(self.class_data_root.iterdir()) |
| 104 | self.num_class_images = len(self.class_images) | 106 | self.num_class_images = len(self.class_images) |
| 105 | self._length = max(self.num_class_images, self.num_instance_images) | 107 | self._length = max(self.num_class_images, self.num_instance_images) |
| 106 | 108 | ||
| @@ -123,10 +125,8 @@ class CSVDataset(Dataset): | |||
| 123 | ] | 125 | ] |
| 124 | ) | 126 | ) |
| 125 | 127 | ||
| 126 | self.cache = {} | ||
| 127 | |||
| 128 | def __len__(self): | 128 | def __len__(self): |
| 129 | return self._length | 129 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 130 | 130 | ||
| 131 | def get_example(self, i): | 131 | def get_example(self, i): |
| 132 | image_path, text = self.data[i % self.num_instance_images] | 132 | image_path, text = self.data[i % self.num_instance_images] |
diff --git a/dreambooth.py b/dreambooth.py index 2df6858..0c58ab5 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -433,7 +433,7 @@ class Checkpointer: | |||
| 433 | del image_grid | 433 | del image_grid |
| 434 | del stable_latents | 434 | del stable_latents |
| 435 | 435 | ||
| 436 | for data, pool in [(train_data, "train"), (val_data, "val")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
| 437 | all_samples = [] | 437 | all_samples = [] |
| 438 | filename = f"step_{step}_{pool}.png" | 438 | filename = f"step_{step}_{pool}.png" |
| 439 | 439 | ||
| @@ -492,12 +492,11 @@ def main(): | |||
| 492 | 492 | ||
| 493 | if args.with_prior_preservation: | 493 | if args.with_prior_preservation: |
| 494 | class_images_dir = Path(args.class_data_dir) | 494 | class_images_dir = Path(args.class_data_dir) |
| 495 | if not class_images_dir.exists(): | 495 | class_images_dir.mkdir(parents=True, exist_ok=True) |
| 496 | class_images_dir.mkdir(parents=True) | ||
| 497 | cur_class_images = len(list(class_images_dir.iterdir())) | 496 | cur_class_images = len(list(class_images_dir.iterdir())) |
| 498 | 497 | ||
| 499 | if cur_class_images < args.num_class_images: | 498 | if cur_class_images < args.num_class_images: |
| 500 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 | 499 | torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 |
| 501 | pipeline = StableDiffusionPipeline.from_pretrained( | 500 | pipeline = StableDiffusionPipeline.from_pretrained( |
| 502 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 501 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
| 503 | pipeline.enable_attention_slicing() | 502 | pipeline.enable_attention_slicing() |
| @@ -581,7 +580,6 @@ def main(): | |||
| 581 | eps=args.adam_epsilon, | 580 | eps=args.adam_epsilon, |
| 582 | ) | 581 | ) |
| 583 | 582 | ||
| 584 | # TODO (patil-suraj): laod scheduler using args | ||
| 585 | noise_scheduler = DDPMScheduler( | 583 | noise_scheduler = DDPMScheduler( |
| 586 | beta_start=0.00085, | 584 | beta_start=0.00085, |
| 587 | beta_end=0.012, | 585 | beta_end=0.012, |
| @@ -595,7 +593,7 @@ def main(): | |||
| 595 | pixel_values = [example["instance_images"] for example in examples] | 593 | pixel_values = [example["instance_images"] for example in examples] |
| 596 | 594 | ||
| 597 | # concat class and instance examples for prior preservation | 595 | # concat class and instance examples for prior preservation |
| 598 | if args.with_prior_preservation: | 596 | if args.with_prior_preservation and "class_prompt_ids" in examples[0]: |
| 599 | input_ids += [example["class_prompt_ids"] for example in examples] | 597 | input_ids += [example["class_prompt_ids"] for example in examples] |
| 600 | pixel_values += [example["class_images"] for example in examples] | 598 | pixel_values += [example["class_images"] for example in examples] |
| 601 | 599 | ||
| @@ -789,6 +787,8 @@ def main(): | |||
| 789 | 787 | ||
| 790 | train_loss /= len(train_dataloader) | 788 | train_loss /= len(train_dataloader) |
| 791 | 789 | ||
| 790 | accelerator.wait_for_everyone() | ||
| 791 | |||
| 792 | unet.eval() | 792 | unet.eval() |
| 793 | val_loss = 0.0 | 793 | val_loss = 0.0 |
| 794 | 794 | ||
| @@ -812,18 +812,7 @@ def main(): | |||
| 812 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 812 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 813 | 813 | ||
| 814 | with accelerator.autocast(): | 814 | with accelerator.autocast(): |
| 815 | if args.with_prior_preservation: | 815 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 816 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | ||
| 817 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
| 818 | |||
| 819 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 820 | |||
| 821 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | ||
| 822 | reduction="none").mean([1, 2, 3]).mean() | ||
| 823 | |||
| 824 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 825 | else: | ||
| 826 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 827 | 816 | ||
| 828 | loss = loss.detach().item() | 817 | loss = loss.detach().item() |
| 829 | val_loss += loss | 818 | val_loss += loss |
| @@ -851,8 +840,6 @@ def main(): | |||
| 851 | global_step, | 840 | global_step, |
| 852 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 841 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 853 | 842 | ||
| 854 | accelerator.wait_for_everyone() | ||
| 855 | |||
| 856 | # Create the pipeline using using the trained modules and save it. | 843 | # Create the pipeline using using the trained modules and save it. |
| 857 | if accelerator.is_main_process: | 844 | if accelerator.is_main_process: |
| 858 | print("Finished! Saving final checkpoint and resume state.") | 845 | print("Finished! Saving final checkpoint and resume state.") |
