From 46b6c09a18b41edff77c6881529b66733d788abe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 21:28:52 +0200 Subject: Dreambooth: Generate specialized class images from input prompts --- data/dreambooth/csv.py | 112 +++++++++++++--------------- data/dreambooth/prompt.py | 4 +- data/textual_inversion/csv.py | 3 +- dreambooth.py | 168 ++++++++++++++++++------------------------ textual_inversion.py | 6 +- 5 files changed, 129 insertions(+), 164 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index c0b0067..4ebdc13 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule): batch_size, data_file, tokenizer, - instance_prompt, - class_data_root=None, - class_prompt=None, + instance_identifier, + class_identifier=None, size=512, repeats=100, interpolation="bicubic", - identifier="*", center_crop=False, valid_set_size=None, generator=None, @@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent + self.class_root = self.data_root.joinpath("db_cls") + self.class_root.mkdir(parents=True, exist_ok=True) + self.tokenizer = tokenizer - self.instance_prompt = instance_prompt - self.class_data_root = class_data_root - self.class_prompt = class_prompt + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier self.size = size self.repeats = repeats - self.identifier = identifier self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule): def prepare_data(self): metadata = pd.read_csv(self.data_file) - image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] + instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] + class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) - self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] + nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) + skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) + self.data = [(i, c, p, n) + for i, c, p, n, s + in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) + if s != "x"] def setup(self, stage=None): - valid_set_size = int(len(self.data_full) * 0.2) + valid_set_size = int(len(self.data) * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) - train_set_size = len(self.data_full) - valid_set_size - - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) - - train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, class_prompt=self.class_prompt, - size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) - val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, - size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, batch_size=self.batch_size) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, + valid_set_size = max(valid_set_size, 1) + train_set_size = len(self.data) - valid_set_size + + self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) + + train_dataset = CSVDataset(self.data_train, self.tokenizer, + instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + val_dataset = CSVDataset(self.data_val, self.tokenizer, + instance_identifier=self.instance_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, pin_memory=True, collate_fn=self.collate_fn) def train_dataloader(self): @@ -85,39 +90,23 @@ class CSVDataset(Dataset): def __init__(self, data, tokenizer, - instance_prompt, - class_data_root=None, - class_prompt=None, + instance_identifier, + class_identifier=None, size=512, repeats=1, interpolation="bicubic", - identifier="*", center_crop=False, - batch_size=1, ): self.data = data self.tokenizer = tokenizer - self.instance_prompt = instance_prompt - self.identifier = identifier - self.batch_size = batch_size + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier self.cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - - self.class_images = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images) - self._length = max(self.num_class_images, self.num_instance_images) - - self.class_prompt = class_prompt - else: - self.class_data_root = None - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, "bilinear": transforms.InterpolationMode.BILINEAR, "bicubic": transforms.InterpolationMode.BICUBIC, @@ -134,46 +123,49 @@ class CSVDataset(Dataset): ) def __len__(self): - return math.ceil(self._length / self.batch_size) * self.batch_size + return self._length def get_example(self, i): - image_path, prompt, nprompt = self.data[i % self.num_instance_images] + instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] - if image_path in self.cache: - return self.cache[image_path] + if instance_image_path in self.cache: + return self.cache[instance_image_path] example = {} - instance_image = Image.open(image_path) + example["prompts"] = prompt + example["nprompts"] = nprompt + + instance_image = Image.open(instance_image_path) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - prompt = prompt.format(self.identifier) + instance_prompt = prompt.format(self.instance_identifier) - example["prompts"] = prompt - example["nprompts"] = nprompt example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, + instance_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_data_root: - class_image = Image.open(self.class_images[i % self.num_class_images]) + if self.class_identifier: + class_image = Image.open(class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") + class_prompt = prompt.format(self.class_identifier) + example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, + class_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - self.cache[image_path] = example + self.cache[instance_image_path] = example return example def __getitem__(self, i): @@ -185,7 +177,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_data_root: + if self.class_identifier: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py @@ -2,8 +2,9 @@ from torch.utils.data import Dataset class PromptDataset(Dataset): - def __init__(self, prompt, num_samples): + def __init__(self, prompt, nprompt, num_samples): self.prompt = prompt + self.nprompt = nprompt self.num_samples = num_samples def __len__(self): @@ -12,5 +13,6 @@ class PromptDataset(Dataset): def __getitem__(self, index): example = {} example["prompt"] = self.prompt + example["nprompt"] = self.nprompt example["index"] = index return example diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 852b1cb..4c5e27e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule): valid_set_size = int(len(self.data_full) * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) + valid_set_size = max(valid_set_size, 1) train_set_size = len(self.data_full) - valid_set_size self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, placeholder_token=self.placeholder_token, center_crop=self.center_crop) - val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, + val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, placeholder_token=self.placeholder_token, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) diff --git a/dreambooth.py b/dreambooth.py index 9d6b8d6..2fe89ec 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -13,13 +13,12 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler -from pipelines.stable_diffusion.no_check import NoCheck from PIL import Image from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json @@ -56,7 +55,13 @@ def parse_args(): help="A folder containing the training data." ) parser.add_argument( - "--identifier", + "--instance_identifier", + type=str, + default=None, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--class_identifier", type=str, default=None, help="A token to use as a placeholder for the concept.", @@ -217,12 +222,6 @@ def parse_args(): default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - help="The prompt with identifier specifing the instance", - ) parser.add_argument( "--class_data_dir", type=str, @@ -230,12 +229,6 @@ def parse_args(): required=False, help="A folder containing the training data of class images.", ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided intance images.", - ) parser.add_argument( "--prior_loss_weight", type=float, @@ -254,15 +247,6 @@ def parse_args(): type=float, help="Max gradient norm." ) - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior perversation loss. If not have enough images, additional images will be" - " sampled with class_prompt." - ), - ) parser.add_argument( "--config", type=str, @@ -286,21 +270,12 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") - if args.instance_prompt is None: - raise ValueError("You must specify --instance_prompt") - - if args.identifier is None: - raise ValueError("You must specify --identifier") + if args.instance_identifier is None: + raise ValueError("You must specify --instance_identifier") if args.output_dir is None: raise ValueError("You must specify --output_dir") - if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("You must specify --class_data_dir") - if args.class_prompt is None: - raise ValueError("You must specify --class_prompt") - return args @@ -443,7 +418,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -488,47 +463,6 @@ def main(): freeze_params(vae.parameters()) freeze_params(text_encoder.parameters()) - # Generate class images, if necessary - if args.with_prior_preservation: - class_images_dir = Path(args.class_data_dir) - class_images_dir.mkdir(parents=True, exist_ok=True) - cur_class_images = len(list(class_images_dir.iterdir())) - - if cur_class_images < args.num_class_images: - scheduler = EulerAScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - ).to(accelerator.device) - pipeline.enable_attention_slicing() - pipeline.set_progress_bar_config(disable=True) - - num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - - sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - - sample_dataloader = accelerator.prepare(sample_dataloader) - - for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process): - with accelerator.autocast(): - images = pipeline(example["prompt"]).images - - for i, image in enumerate(images): - image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") - - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -564,6 +498,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -579,6 +514,7 @@ def main(): batch = { "prompts": prompts, + "nprompts": nprompts, "input_ids": input_ids, "pixel_values": pixel_values, } @@ -588,11 +524,9 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, + instance_identifier=args.instance_identifier, + class_identifier=args.class_identifier, size=args.resolution, - identifier=args.identifier, repeats=args.repeats, center_crop=args.center_crop, valid_set_size=args.sample_batch_size*args.sample_batches, @@ -601,6 +535,46 @@ def main(): datamodule.prepare_data() datamodule.setup() + if args.class_identifier: + missing_data = [item for item in datamodule.data if not item[1].exists()] + + if len(missing_data) != 0: + batched_data = [missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size)] + + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ).to(accelerator.device) + pipeline.enable_attention_slicing() + + for batch in batched_data: + image_name = [p[1] for p in batch] + prompt = [p[2] for p in batch] + nprompt = [p[3] for p in batch] + + with accelerator.autocast(): + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + num_inference_steps=args.sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + train_dataloader = datamodule.train_dataloader() val_dataloader = datamodule.val_dataloader() @@ -718,23 +692,22 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - with accelerator.autocast(): - if args.with_prior_preservation: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) - # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + # Compute instance loss + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() - # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, - reduction="none").mean([1, 2, 3]).mean() + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, + reduction="none").mean([1, 2, 3]).mean() - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) if accelerator.sync_gradients: @@ -786,8 +759,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - with accelerator.autocast(): - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss diff --git a/textual_inversion.py b/textual_inversion.py index 5fc2338..4c4da29 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -694,8 +694,7 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - with accelerator.autocast(): - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) @@ -766,8 +765,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - with accelerator.autocast(): - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-70-g09d2