From adc52fb8821a496bc8d78235bf10466b39df03e0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 19:19:52 +0100
Subject: Updates

---
 train_dreambooth.py | 228 ++++++++++++++++++++++++++++------------------------
 1 file changed, 124 insertions(+), 104 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -105,6 +107,12 @@ def parse_args():
         default=0.1,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
    )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
             device=accelerator.device
         )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]
 
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
 
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
         text_encoder.text_model.encoder.requires_grad_(False)
         text_encoder.text_model.final_layer_norm.requires_grad_(False)
         text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr/unet": lr_scheduler.get_last_lr()[0],
-                    "lr/text": lr_scheduler.get_last_lr()[1]
+                    "lr": lr_scheduler.get_last_lr()[0]
                 }
                 if args.use_ema:
                     logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -1029,20 +1049,20 @@ def main():
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
 
-- cgit v1.2.3-54-g00ecf