From adc52fb8821a496bc8d78235bf10466b39df03e0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 19:19:52 +0100
Subject: Updates

---
 models/clip/embeddings.py |  11 +++
 models/clip/tokenizer.py  |  76 ++++++++++++----
 train_dreambooth.py       | 228 +++++++++++++++++++++++++---------------------
 train_ti.py               |  51 +++++------
 training/lr.py            |   6 ++
 training/optimization.py  |   2 +-
 6 files changed, 227 insertions(+), 147 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
+
+
+def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
+    text_encoder.text_model.embeddings.make_permanent()
+
+    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_encoder.text_model.embeddings = text_embeddings
+
+    return text_embeddings
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
 import copy
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Literal
 
 import numpy as np
 
 from transformers import CLIPTokenizer
 
 
+def shuffle_all(tokens: list[int]):
+    if len(tokens) >= 2:
+        tokens = copy.copy(tokens)
+        np.random.shuffle(tokens)
+    return tokens
+
+
+def shuffle_leading(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[:-1]
+        np.random.shuffle(subtokens)
+        tokens = subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_trailing(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[1:]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens
+    return tokens
+
+
+def shuffle_between(tokens: list[int]):
+    if len(tokens) >= 4:
+        subtokens = tokens[1:-1]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_none(tokens: list[int]):
+    return tokens
+
+
+def shuffle_auto(tokens: list[int]):
+    if len(tokens) >= 4:
+        return shuffle_between(tokens)
+    if len(tokens) >= 3:
+        return shuffle_trailing(tokens)
+    return shuffle_all(tokens)
+
+
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
     ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = False
-
-    def set_use_vector_shuffle(self, enable: bool):
-        self.vector_shuffle = enable
+        self.vector_shuffle = shuffle_none
+
+    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+        if algorithm == "leading":
+            self.vector_shuffle = shuffle_leading
+        elif algorithm == "trailing":
+            self.vector_shuffle = shuffle_trailing
+        elif algorithm == "between":
+            self.vector_shuffle = shuffle_between
+        elif algorithm == "auto":
+            self.vector_shuffle = shuffle_auto
+        elif algorithm == True or algorithm == "all":
+            self.vector_shuffle = shuffle_all
+        else:
+            self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        if id in self.token_map:
-            tokens = self.token_map[id]
-
-            if self.vector_shuffle and len(tokens) > 2:
-                subtokens = tokens[1:-1]
-                np.random.shuffle(subtokens)
-                tokens = tokens[:1] + subtokens + tokens[-1:]
-
-            return tokens
-        else:
-            return [id]
+        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
 
     def expand_ids(self, ids: list[int]):
         return [
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -105,6 +107,12 @@ def parse_args():
         default=0.1,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
             device=accelerator.device
         )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]
 
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
 
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
         text_encoder.text_model.encoder.requires_grad_(False)
         text_encoder.text_model.final_layer_norm.requires_grad_(False)
         text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr/unet": lr_scheduler.get_last_lr()[0],
-                    "lr/text": lr_scheduler.get_last_lr()[1]
+                    "lr": lr_scheduler.get_last_lr()[0]
                 }
                 if args.use_ema:
                     logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -1029,20 +1049,20 @@ def main():
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
 
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -155,6 +154,12 @@ def parse_args():
         default=0.1,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
 
-        logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -585,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main():
 
         return loss, acc, bsz
 
-    if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
-
-        plt.savefig(basepath.joinpath("lr.png"))
-        plt.close()
-
-        quit()
-
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..3abd2f2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -102,6 +102,12 @@ class LRFinder():
             losses.append(loss)
             accs.append(acc)
 
+            self.accelerator.log({
+                "loss": loss,
+                "acc": acc,
+                "lr": lr,
+            }, step=epoch)
+
             progress_bar.set_postfix({
                 "loss": loss,
                 "loss/best": best_loss,
diff --git a/training/optimization.py b/training/optimization.py
index 3340544..a79944f 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -15,7 +15,7 @@ def get_one_cycle_schedule(
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
-    annealing_exp: int = 2,
+    annealing_exp: int = 1,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
--
cgit v1.2.3-70-g09d2
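Illustration (not part of the commit): the main behavioral change of this patch is that MultiCLIPTokenizer now delegates multi-vector token expansion to one of several shuffle strategies, selected via the new --vector_shuffle option. The minimal sketch below shows what the helpers added to models/clip/tokenizer.py do to one multi-vector token's sub-token ids. It assumes the repository root is on PYTHONPATH; the ids are made-up placeholders.

# Illustrative sketch only; not part of the patch. Assumes the repository root
# is on PYTHONPATH so the shuffle helpers added above are importable.
import numpy as np

from models.clip.tokenizer import (
    shuffle_all,
    shuffle_leading,
    shuffle_trailing,
    shuffle_between,
    shuffle_auto,
)

np.random.seed(0)
ids = [101, 102, 103, 104, 105]   # hypothetical sub-token ids of one multi-vector token

print(shuffle_all(ids))       # any position may move
print(shuffle_leading(ids))   # last id stays fixed
print(shuffle_trailing(ids))  # first id stays fixed
print(shuffle_between(ids))   # first and last ids stay fixed
print(shuffle_auto(ids))      # 5 ids here, so "auto" delegates to shuffle_between

With --vector_shuffle set to "off" (or any unrecognized value), set_use_vector_shuffle() falls back to shuffle_none, and expand_id() returns the stored ids unchanged.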