import argparse
import itertools
import math
import os
import datetime
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from pipelines.stable_diffusion.no_check import NoCheck
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify

from data.dreambooth.csv import CSVDataModule
from data.dreambooth.prompt import PromptDataset

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data.",
    )
    parser.add_argument(
        "--identifier",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="dreambooth-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=500,
        help="How often to save a checkpoint and sample images",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=512,
        help="Size of sample images",
    )
    parser.add_argument(
        "--stable_sample_batches",
        type=int,
        default=1,
        help="Number of fixed seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--random_sample_batches",
        type=int,
        default=1,
        help="Number of random seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as the provided instance images.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of the prior preservation loss.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimum number of class images for the prior preservation loss. If there are not enough images already"
            " present, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )
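
    # Arguments may also be supplied via --config: values from the JSON file act as defaults,
    # while flags passed explicitly on the command line still take precedence.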
    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.instance_prompt is None:
        raise ValueError("You must specify --instance_prompt")

    if args.identifier is None:
        raise ValueError("You must specify --identifier")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify --class_data_dir")
        if args.class_prompt is None:
            raise ValueError("You must specify --class_prompt")

    return args


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class Checkpointer:
    def __init__(
        self,
        datamodule,
        accelerator,
        vae,
        unet,
        tokenizer,
        text_encoder,
        output_dir,
        sample_image_size,
        random_sample_batches,
        sample_batch_size,
        stable_sample_batches,
        seed
    ):
        self.datamodule = datamodule
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.random_sample_batches = random_sample_batches
        self.sample_batch_size = sample_batch_size
        self.stable_sample_batches = stable_sample_batches

    @torch.no_grad()
    def checkpoint(self):
        print("Saving model...")

        unwrapped = self.accelerator.unwrap_model(self.unet)
        pipeline = StableDiffusionPipeline(
            text_encoder=self.text_encoder,
            vae=self.vae,
            unet=unwrapped,
            tokenizer=self.tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=NoCheck(),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        )
        pipeline.enable_attention_slicing()
        pipeline.save_pretrained(f"{self.output_dir}/model")

        del unwrapped
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
        samples_path = f"{self.output_dir}/samples/{mode}"
        os.makedirs(samples_path, exist_ok=True)
        checker = NoCheck()

        unwrapped = self.accelerator.unwrap_model(self.unet)
        pipeline = StableDiffusionPipeline(
            text_encoder=self.text_encoder,
            vae=self.vae,
            unet=unwrapped,
            tokenizer=self.tokenizer,
            scheduler=LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            ),
            safety_checker=NoCheck(),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        ).to(self.accelerator.device)
        pipeline.enable_attention_slicing()

        data = {
            "training": self.datamodule.train_dataloader(),
            "validation": self.datamodule.val_dataloader(),
        }[mode]

        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
            stable_latents = torch.randn(
                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                device=pipeline.device,
                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
            )
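            # The fixed seed keeps these initial latents identical from checkpoint to checkpoint,
            # so the "stable" validation grids are directly comparable across training steps.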

            all_samples = []
            filename = f"stable_step_{step}.png"

            data_enum = enumerate(data)

            # Generate and save stable samples
            for i in range(0, self.stable_sample_batches):
                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]

                with self.accelerator.autocast():
                    samples = pipeline(
                        prompt=prompt,
                        height=self.sample_image_size,
                        latents=stable_latents[:len(prompt)],
                        width=self.sample_image_size,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    )["sample"]

                all_samples += samples
                del samples

            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
            image_grid.save(f"{samples_path}/{filename}")

            del all_samples
            del image_grid
            del stable_latents

        all_samples = []
        filename = f"step_{step}.png"

        data_enum = enumerate(data)

        # Generate and save random samples
        for i in range(0, self.random_sample_batches):
            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]

            with self.accelerator.autocast():
                samples = pipeline(
                    prompt=prompt,
                    height=self.sample_image_size,
                    width=self.sample_image_size,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type='pil'
                )["sample"]

            all_samples += samples
            del samples

        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
        image_grid.save(f"{samples_path}/{filename}")

        del all_samples
        del image_grid

        del checker
        del unwrapped
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}"
    os.makedirs(basepath, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                with accelerator.autocast():
                    images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")

            del pipeline

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path + '/tokenizer'
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path + '/text_encoder',
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path + '/vae',
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path + '/unet',
    )

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # slice_size = unet.config.attention_head_dim // 2
    # unet.set_attention_slice(slice_size)

    # Freeze text_encoder and vae
    freeze_params(vae.parameters())
    freeze_params(text_encoder.parameters())

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage, or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Initialize the optimizer
    optimizer = optimizer_class(
        unet.parameters(),  # only optimize unet
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    def collate_fn(examples):
        prompts = [example["prompts"] for example in examples]
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
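        # Class examples are appended after the instance examples so the training loop can later
        # split the two halves apart again with torch.chunk for the prior-preservation loss.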

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

        batch = {
            "prompts": prompts,
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    datamodule = CSVDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        size=args.resolution,
        identifier=args.identifier,
        repeats=args.repeats,
        center_crop=args.center_crop,
        collate_fn=collate_fn)

    datamodule.prepare_data()
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()

    checkpointer = Checkpointer(
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        output_dir=basepath,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        random_sample_batches=args.random_sample_batches,
        stable_sample_batches=args.stable_sample_batches,
        seed=args.seed
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move text_encoder and vae to device
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)

    # Keep text_encoder and vae in eval mode as we don't train them
    text_encoder.eval()
    vae.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(
        args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0
    min_val_loss = np.inf

    checkpointer.save_samples(
        "validation", 0, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

    # Only show the progress bars once on each machine.
    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
    local_progress_bar.set_description("Steps")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Global steps")

    try:
        for epoch in range(args.num_train_epochs):
            local_progress_bar.reset()

            unet.train()
            train_loss = 0.0

            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(unet):
                    with accelerator.autocast():
                        # Convert images to latent space
                        with torch.no_grad():
                            latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                            latents = latents * 0.18215

                        # Sample noise that we'll add to the latents
                        noise = torch.randn(latents.shape).to(latents.device)
                        bsz = latents.shape[0]
                        # Sample a random timestep for each image
                        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                                  (bsz,), device=latents.device)
                        timesteps = timesteps.long()

                        # Add noise to the latents according to the noise magnitude at each timestep
                        # (this is the forward diffusion process)
                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                        # Get the text embedding for conditioning
                        with torch.no_grad():
                            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                        # Predict the noise residual
                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                        if args.with_prior_preservation:
                            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                            noise, noise_prior = torch.chunk(noise, 2, dim=0)

                            # Compute instance loss
                            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                            # Compute prior loss
                            prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

                            # Add the prior loss to the instance loss.
                            loss = loss + args.prior_loss_weight * prior_loss
                        else:
                            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    accelerator.backward(loss)

                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)

                    optimizer.step()
                    if not accelerator.optimizer_step_was_skipped:
                        lr_scheduler.step()
                    optimizer.zero_grad()

                    loss = loss.detach().item()
                    train_loss += loss

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    local_progress_bar.update(1)
                    progress_bar.update(1)

                    global_step += 1

                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                        local_progress_bar.clear()
                        progress_bar.clear()

                        checkpointer.save_samples(
                            "training", global_step, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                local_progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

            train_loss /= len(train_dataloader)

            unet.eval()
            val_loss = 0.0

            for step, batch in enumerate(val_dataloader):
                with torch.no_grad(), accelerator.autocast():
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    latents = latents * 0.18215

                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

                    if args.with_prior_preservation:
                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                        noise, noise_prior = torch.chunk(noise, 2, dim=0)

                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

                        loss = loss + args.prior_loss_weight * prior_loss
                    else:
                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    loss = loss.detach().item()
                    val_loss += loss

                if accelerator.sync_gradients:
                    local_progress_bar.update(1)
                    progress_bar.update(1)

                logs = {"mode": "validation", "loss": loss}
                local_progress_bar.set_postfix(**logs)

            val_loss /= len(val_dataloader)

            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)

            local_progress_bar.clear()
            progress_bar.clear()

            if min_val_loss > val_loss:
                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                min_val_loss = val_loss

                checkpointer.save_samples(
                    "validation", global_step, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

        accelerator.wait_for_everyone()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.checkpoint()

        accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint()
        accelerator.end_training()
        quit()


if __name__ == "__main__":
    main()