From 46b6c09a18b41edff77c6881529b66733d788abe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 21:28:52 +0200 Subject: Dreambooth: Generate specialized class images from input prompts --- dreambooth.py | 168 ++++++++++++++++++++++++---------------------------------- 1 file changed, 70 insertions(+), 98 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 9d6b8d6..2fe89ec 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -13,13 +13,12 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler -from pipelines.stable_diffusion.no_check import NoCheck from PIL import Image from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json @@ -56,7 +55,13 @@ def parse_args(): help="A folder containing the training data." ) parser.add_argument( - "--identifier", + "--instance_identifier", + type=str, + default=None, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--class_identifier", type=str, default=None, help="A token to use as a placeholder for the concept.", @@ -217,12 +222,6 @@ def parse_args(): default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) - parser.add_argument( - "--instance_prompt", - type=str, - default=None, - help="The prompt with identifier specifing the instance", - ) parser.add_argument( "--class_data_dir", type=str, @@ -230,12 +229,6 @@ def parse_args(): required=False, help="A folder containing the training data of class images.", ) - parser.add_argument( - "--class_prompt", - type=str, - default=None, - help="The prompt to specify images in the same class as provided intance images.", - ) parser.add_argument( "--prior_loss_weight", type=float, @@ -254,15 +247,6 @@ def parse_args(): type=float, help="Max gradient norm." ) - parser.add_argument( - "--num_class_images", - type=int, - default=100, - help=( - "Minimal class images for prior perversation loss. If not have enough images, additional images will be" - " sampled with class_prompt." - ), - ) parser.add_argument( "--config", type=str, @@ -286,21 +270,12 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") - if args.instance_prompt is None: - raise ValueError("You must specify --instance_prompt") - - if args.identifier is None: - raise ValueError("You must specify --identifier") + if args.instance_identifier is None: + raise ValueError("You must specify --instance_identifier") if args.output_dir is None: raise ValueError("You must specify --output_dir") - if args.with_prior_preservation: - if args.class_data_dir is None: - raise ValueError("You must specify --class_data_dir") - if args.class_prompt is None: - raise ValueError("You must specify --class_prompt") - return args @@ -443,7 +418,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -488,47 +463,6 @@ def main(): freeze_params(vae.parameters()) freeze_params(text_encoder.parameters()) - # Generate class images, if necessary - if args.with_prior_preservation: - class_images_dir = Path(args.class_data_dir) - class_images_dir.mkdir(parents=True, exist_ok=True) - cur_class_images = len(list(class_images_dir.iterdir())) - - if cur_class_images < args.num_class_images: - scheduler = EulerAScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - ).to(accelerator.device) - pipeline.enable_attention_slicing() - pipeline.set_progress_bar_config(disable=True) - - num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - - sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - - sample_dataloader = accelerator.prepare(sample_dataloader) - - for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process): - with accelerator.autocast(): - images = pipeline(example["prompt"]).images - - for i, image in enumerate(images): - image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") - - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -564,6 +498,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -579,6 +514,7 @@ def main(): batch = { "prompts": prompts, + "nprompts": nprompts, "input_ids": input_ids, "pixel_values": pixel_values, } @@ -588,11 +524,9 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, + instance_identifier=args.instance_identifier, + class_identifier=args.class_identifier, size=args.resolution, - identifier=args.identifier, repeats=args.repeats, center_crop=args.center_crop, valid_set_size=args.sample_batch_size*args.sample_batches, @@ -601,6 +535,46 @@ def main(): datamodule.prepare_data() datamodule.setup() + if args.class_identifier: + missing_data = [item for item in datamodule.data if not item[1].exists()] + + if len(missing_data) != 0: + batched_data = [missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size)] + + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ).to(accelerator.device) + pipeline.enable_attention_slicing() + + for batch in batched_data: + image_name = [p[1] for p in batch] + prompt = [p[2] for p in batch] + nprompt = [p[3] for p in batch] + + with accelerator.autocast(): + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + num_inference_steps=args.sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + train_dataloader = datamodule.train_dataloader() val_dataloader = datamodule.val_dataloader() @@ -718,23 +692,22 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - with accelerator.autocast(): - if args.with_prior_preservation: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) - # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + # Compute instance loss + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() - # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, - reduction="none").mean([1, 2, 3]).mean() + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, + reduction="none").mean([1, 2, 3]).mean() - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) if accelerator.sync_gradients: @@ -786,8 +759,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - with accelerator.autocast(): - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-54-g00ecf