From 9d6c75262b6919758e781b8333428861a5bf7ede Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Dec 2022 11:02:49 +0100 Subject: Added learning rate finder --- environment.yaml | 1 + train_dreambooth.py | 129 ++++++++++++++++---------------------- train_ti.py | 174 ++++++++++++++++++++++++++-------------------------- training/lr.py | 115 ++++++++++++++++++++++++++++++++++ 4 files changed, 257 insertions(+), 162 deletions(-) create mode 100644 training/lr.py diff --git a/environment.yaml b/environment.yaml index 179fa38..c006379 100644 --- a/environment.yaml +++ b/environment.yaml @@ -5,6 +5,7 @@ channels: - defaults dependencies: - cudatoolkit=11.3 + - matplotlib=3.6.2 - numpy=1.23.4 - pip=22.3.1 - python=3.9.15 diff --git a/train_dreambooth.py b/train_dreambooth.py index 08bc9e0..a62cec9 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -843,6 +843,58 @@ def main(): ) global_progress_bar.set_description("Total progress") + def loop(batch): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noisy_latents.to(dtype=unet.dtype) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + acc = (model_pred == latents).float().mean() + + return loss, acc, bsz + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -859,54 +911,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - noisy_latents = noisy_latents.to(dtype=unet.dtype) - - # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - - # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.num_class_images != 0: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) accelerator.backward(loss) @@ -960,33 +965,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - noisy_latents = noisy_latents.to(dtype=unet.dtype) - - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) avg_loss_val.update(loss.detach_(), bsz) avg_acc_val.update(acc.detach_(), bsz) diff --git a/train_ti.py b/train_ti.py index 6e30ac3..ab00b60 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,10 +1,8 @@ import argparse import itertools import math -import os import datetime import logging -import json from pathlib import Path import torch @@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule +from training.lr import LRFinder from training.ti import patch_trainable_embeddings from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params from models.clip.prompt import PromptProcessor @@ -172,6 +171,11 @@ def parse_args(): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--find_lr", + action="store_true", + help="Automatically find a learning rate (no training).", + ) parser.add_argument( "--learning_rate", type=float, @@ -225,7 +229,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=1e-2, help="Weight decay to use." ) parser.add_argument( @@ -447,16 +451,23 @@ def main(): global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") basepath = Path(args.output_dir).joinpath(slugify(args.project), now) - basepath.mkdir(parents=True, exist_ok=True) - accelerator = Accelerator( - log_with=LoggerType.TENSORBOARD, - logging_dir=f"{basepath}", - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision - ) + if args.find_lr: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + else: + basepath.mkdir(parents=True, exist_ok=True) - logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{basepath}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) args.seed = args.seed or (torch.random.seed() >> 32) set_seed(args.seed) @@ -537,6 +548,9 @@ def main(): args.train_batch_size * accelerator.num_processes ) + if args.find_lr: + args.learning_rate = 1e2 + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: @@ -671,7 +685,9 @@ def main(): warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps - if args.lr_scheduler == "one_cycle": + if args.find_lr: + lr_scheduler = None + elif args.lr_scheduler == "one_cycle": lr_scheduler = get_one_cycle_schedule( optimizer=optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, @@ -713,6 +729,63 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + def loop(batch): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + acc = (model_pred == latents).float().mean() + + return loss, acc, bsz + + if args.find_lr: + lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) + lr_finder.run() + quit() + # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: @@ -786,54 +859,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - - # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.num_class_images != 0: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) accelerator.backward(loss) @@ -873,33 +899,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) avg_loss_val.update(loss.detach_(), bsz) avg_acc_val.update(acc.detach_(), bsz) diff --git a/training/lr.py b/training/lr.py new file mode 100644 index 0000000..dd37baa --- /dev/null +++ b/training/lr.py @@ -0,0 +1,115 @@ +import numpy as np +from torch.optim.lr_scheduler import LambdaLR +from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +from training.util import AverageMeter + + +class LRFinder(): + def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn): + self.accelerator = accelerator + self.model = model + self.optimizer = optimizer + self.train_dataloader = train_dataloader + self.loss_fn = loss_fn + + def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5): + best_loss = None + lrs = [] + losses = [] + + lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) + + progress_bar = tqdm( + range(num_epochs * num_steps), + disable=not self.accelerator.is_local_main_process, + dynamic_ncols=True + ) + progress_bar.set_description("Epoch X / Y") + + for epoch in range(num_epochs): + progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") + + avg_loss = AverageMeter() + + for step, batch in enumerate(self.train_dataloader): + with self.accelerator.accumulate(self.model): + loss, acc, bsz = self.loss_fn(batch) + + self.accelerator.backward(loss) + + self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) + + avg_loss.update(loss.detach_(), bsz) + + if step >= num_steps: + break + + if self.accelerator.sync_gradients: + progress_bar.update(1) + + lr_scheduler.step() + + loss = avg_loss.avg.item() + if epoch == 0: + best_loss = loss + else: + if smooth_f > 0: + loss = smooth_f * loss + (1 - smooth_f) * losses[-1] + if loss < best_loss: + best_loss = loss + + lr = lr_scheduler.get_last_lr()[0] + + lrs.append(lr) + losses.append(loss) + + progress_bar.set_postfix({ + "loss": loss, + "best": best_loss, + "lr": lr, + }) + + if loss > diverge_th * best_loss: + print("Stopping early, the loss has diverged") + break + + fig, ax = plt.subplots() + ax.plot(lrs, losses) + + print("LR suggestion: steepest gradient") + min_grad_idx = None + try: + min_grad_idx = (np.gradient(np.array(losses))).argmin() + except ValueError: + print( + "Failed to compute the gradients, there might not be enough points." + ) + if min_grad_idx is not None: + print("Suggested LR: {:.2E}".format(lrs[min_grad_idx])) + ax.scatter( + lrs[min_grad_idx], + losses[min_grad_idx], + s=75, + marker="o", + color="red", + zorder=3, + label="steepest gradient", + ) + ax.legend() + + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if fig is not None: + plt.show() + + +def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): + def lr_lambda(current_epoch: int): + return (current_epoch / num_epochs) ** 5 + + return LambdaLR(optimizer, lr_lambda, last_epoch) -- cgit v1.2.3-70-g09d2