From 9d6c75262b6919758e781b8333428861a5bf7ede Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Dec 2022 11:02:49 +0100 Subject: Added learning rate finder --- train_ti.py | 174 ++++++++++++++++++++++++++++++------------------------------ 1 file changed, 87 insertions(+), 87 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 6e30ac3..ab00b60 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,10 +1,8 @@ import argparse import itertools import math -import os import datetime import logging -import json from pathlib import Path import torch @@ -24,6 +22,7 @@ from common import load_text_embeddings, load_text_embedding, load_config from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule +from training.lr import LRFinder from training.ti import patch_trainable_embeddings from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params from models.clip.prompt import PromptProcessor @@ -172,6 +171,11 @@ def parse_args(): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--find_lr", + action="store_true", + help="Automatically find a learning rate (no training).", + ) parser.add_argument( "--learning_rate", type=float, @@ -225,7 +229,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=1e-2, help="Weight decay to use." ) parser.add_argument( @@ -447,16 +451,23 @@ def main(): global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") basepath = Path(args.output_dir).joinpath(slugify(args.project), now) - basepath.mkdir(parents=True, exist_ok=True) - accelerator = Accelerator( - log_with=LoggerType.TENSORBOARD, - logging_dir=f"{basepath}", - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision - ) + if args.find_lr: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + else: + basepath.mkdir(parents=True, exist_ok=True) - logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{basepath}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) args.seed = args.seed or (torch.random.seed() >> 32) set_seed(args.seed) @@ -537,6 +548,9 @@ def main(): args.train_batch_size * accelerator.num_processes ) + if args.find_lr: + args.learning_rate = 1e2 + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: @@ -671,7 +685,9 @@ def main(): warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps - if args.lr_scheduler == "one_cycle": + if args.find_lr: + lr_scheduler = None + elif args.lr_scheduler == "one_cycle": lr_scheduler = get_one_cycle_schedule( optimizer=optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, @@ -713,6 +729,63 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + def loop(batch): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + acc = (model_pred == latents).float().mean() + + return loss, acc, bsz + + if args.find_lr: + lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) + lr_finder.run() + quit() + # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: @@ -786,54 +859,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - - # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.num_class_images != 0: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) accelerator.backward(loss) @@ -873,33 +899,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) avg_loss_val.update(loss.detach_(), bsz) avg_acc_val.update(acc.detach_(), bsz) -- cgit v1.2.3-54-g00ecf