From 2a65b4eb29e4874c153a9517ab06b93481c2d238 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Sep 2022 18:32:15 +0200 Subject: Batches of size 1 cause error: Expected query.is_contiguous() to be true, but got false --- data/dreambooth/csv.py | 18 +++++++++--------- dreambooth.py | 27 +++++++-------------------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 85ed4a5..99bcf12 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -1,3 +1,4 @@ +import math import os import pandas as pd from pathlib import Path @@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule): train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, class_data_root=self.class_data_root, class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, repeats=self.repeats) + center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop) + center_crop=self.center_crop, batch_size=self.batch_size) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) @@ -85,22 +85,24 @@ class CSVDataset(Dataset): interpolation="bicubic", identifier="*", center_crop=False, + batch_size=1, ): self.data = data self.tokenizer = tokenizer self.instance_prompt = instance_prompt + self.identifier = identifier + self.batch_size = batch_size + self.cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats - self.identifier = identifier - if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images = list(Path(class_data_root).iterdir()) + self.class_images = list(self.class_data_root.iterdir()) self.num_class_images = len(self.class_images) self._length = max(self.num_class_images, self.num_instance_images) @@ -123,10 +125,8 @@ class CSVDataset(Dataset): ] ) - self.cache = {} - def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size def get_example(self, i): image_path, text = self.data[i % self.num_instance_images] diff --git a/dreambooth.py b/dreambooth.py index 2df6858..0c58ab5 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -433,7 +433,7 @@ class Checkpointer: del image_grid del stable_latents - for data, pool in [(train_data, "train"), (val_data, "val")]: + for data, pool in [(val_data, "val"), (train_data, "train")]: all_samples = [] filename = f"step_{step}_{pool}.png" @@ -492,12 +492,11 @@ def main(): if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) + class_images_dir.mkdir(parents=True, exist_ok=True) cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype) pipeline.enable_attention_slicing() @@ -581,7 +580,6 @@ def main(): eps=args.adam_epsilon, ) - # TODO (patil-suraj): laod scheduler using args noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, @@ -595,7 +593,7 @@ def main(): pixel_values = [example["instance_images"] for example in examples] # concat class and instance examples for prior preservation - if args.with_prior_preservation: + if args.with_prior_preservation and "class_prompt_ids" in examples[0]: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -789,6 +787,8 @@ def main(): train_loss /= len(train_dataloader) + accelerator.wait_for_everyone() + unet.eval() val_loss = 0.0 @@ -812,18 +812,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) with accelerator.autocast(): - if args.with_prior_preservation: - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) - - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() - - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, - reduction="none").mean([1, 2, 3]).mean() - - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss @@ -851,8 +840,6 @@ def main(): global_step, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - accelerator.wait_for_everyone() - # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") -- cgit v1.2.3-70-g09d2