From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 20 Jan 2023 14:26:17 +0100 Subject: Restored LR finder --- environment.yaml | 2 +- train_dreambooth.py | 10 +- train_ti.py | 21 +++- training/functional.py | 35 ++++-- training/lr.py | 266 +++++----------------------------------- training/optimization.py | 19 +++ training/strategy/dreambooth.py | 4 +- training/strategy/ti.py | 5 +- training/util.py | 146 +--------------------- 9 files changed, 111 insertions(+), 397 deletions(-) diff --git a/environment.yaml b/environment.yaml index 03345c6..c992759 100644 --- a/environment.yaml +++ b/environment.yaml @@ -25,4 +25,4 @@ dependencies: - test-tube>=0.7.5 - transformers==4.25.1 - triton==2.0.0.dev20221202 - - xformers==0.0.16rc403 + - xformers==0.0.16.dev430 diff --git a/train_dreambooth.py b/train_dreambooth.py index 9c1e41c..a70c80e 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -16,6 +16,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter from training.functional import train, get_models +from training.lr import plot_metrics from training.strategy.dreambooth import dreambooth_strategy from training.optimization import get_scheduler from training.util import save_args @@ -524,6 +525,10 @@ def main(): args.train_batch_size * accelerator.num_processes ) + if args.find_lr: + args.learning_rate = 1e-6 + args.lr_scheduler = "exponential_growth" + if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -602,11 +607,12 @@ def main(): warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, cycles=args.lr_cycles, + end_lr=1e2, train_epochs=args.num_train_epochs, warmup_epochs=args.lr_warmup_epochs, ) - trainer( + metrics = trainer( strategy=dreambooth_strategy, project="dreambooth", train_dataloader=datamodule.train_dataloader, @@ -634,6 +640,8 @@ def main(): sample_image_size=args.sample_image_size, ) + plot_metrics(metrics, output_dir.joinpath("lr.png")) + if __name__ == "__main__": main() diff --git a/train_ti.py b/train_ti.py index 451b61b..c118aab 100644 --- a/train_ti.py +++ b/train_ti.py @@ -15,6 +15,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter from training.functional import train, add_placeholder_tokens, get_models +from training.lr import plot_metrics from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler from training.util import save_args @@ -60,6 +61,12 @@ def parse_args(): default=None, help="The name of the current project.", ) + parser.add_argument( + "--skip_first", + type=int, + default=0, + help="Tokens to skip training for.", + ) parser.add_argument( "--placeholder_tokens", type=str, @@ -407,7 +414,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay", - default=10, + default=1e0, type=float, help="Embedding decay factor." ) @@ -543,6 +550,10 @@ def main(): args.train_batch_size * accelerator.num_processes ) + if args.find_lr: + args.learning_rate = 1e-5 + args.lr_scheduler = "exponential_growth" + if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -596,6 +607,9 @@ def main(): ) def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): + if i < args.skip_first: + return + if len(placeholder_tokens) == 1: sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") else: @@ -656,11 +670,12 @@ def main(): warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, cycles=args.lr_cycles, + end_lr=1e3, train_epochs=args.num_train_epochs, warmup_epochs=args.lr_warmup_epochs, ) - trainer( + metrics = trainer( project="textual_inversion", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, @@ -672,6 +687,8 @@ def main(): placeholder_token_ids=placeholder_token_ids, ) + plot_metrics(metrics, output_dir.joinpath("lr.png")) + if args.simultaneous: run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) else: diff --git a/training/functional.py b/training/functional.py index fb135c4..c373ac9 100644 --- a/training/functional.py +++ b/training/functional.py @@ -7,7 +7,6 @@ from pathlib import Path import itertools import torch -import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader @@ -373,8 +372,12 @@ def train_loop( avg_loss_val = AverageMeter() avg_acc_val = AverageMeter() - max_acc = 0.0 - max_acc_val = 0.0 + best_acc = 0.0 + best_acc_val = 0.0 + + lrs = [] + losses = [] + accs = [] local_progress_bar = tqdm( range(num_training_steps_per_epoch + num_val_steps_per_epoch), @@ -457,6 +460,8 @@ def train_loop( accelerator.wait_for_everyone() + lrs.append(lr_scheduler.get_last_lr()[0]) + on_after_epoch(lr_scheduler.get_last_lr()[0]) if val_dataloader is not None: @@ -498,18 +503,24 @@ def train_loop( global_progress_bar.clear() if accelerator.is_main_process: - if avg_acc_val.avg.item() > max_acc_val: + if avg_acc_val.avg.item() > best_acc_val: accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") + f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") - max_acc_val = avg_acc_val.avg.item() + best_acc_val = avg_acc_val.avg.item() + + losses.append(avg_loss_val.avg.item()) + accs.append(avg_acc_val.avg.item()) else: if accelerator.is_main_process: - if avg_acc.avg.item() > max_acc: + if avg_acc.avg.item() > best_acc: accelerator.print( - f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") + f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") - max_acc = avg_acc.avg.item() + best_acc = avg_acc.avg.item() + + losses.append(avg_loss.avg.item()) + accs.append(avg_acc.avg.item()) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: @@ -523,6 +534,8 @@ def train_loop( on_checkpoint(global_step + global_step_offset, "end") raise KeyboardInterrupt + return lrs, losses, accs + def train( accelerator: Accelerator, @@ -582,7 +595,7 @@ def train( if accelerator.is_main_process: accelerator.init_trackers(project) - train_loop( + metrics = train_loop( accelerator=accelerator, optimizer=optimizer, lr_scheduler=lr_scheduler, @@ -598,3 +611,5 @@ def train( accelerator.end_training() accelerator.free_memory() + + return metrics diff --git a/training/lr.py b/training/lr.py index 9690738..f5b362f 100644 --- a/training/lr.py +++ b/training/lr.py @@ -1,238 +1,36 @@ -import math -from contextlib import _GeneratorContextManager, nullcontext -from typing import Callable, Any, Tuple, Union -from functools import partial +from pathlib import Path import matplotlib.pyplot as plt -import numpy as np -import torch -from torch.optim.lr_scheduler import LambdaLR -from tqdm.auto import tqdm -from training.functional import TrainingCallbacks -from training.util import AverageMeter - -def noop(*args, **kwards): - pass - - -def noop_ctx(*args, **kwards): - return nullcontext() - - -class LRFinder(): - def __init__( - self, - accelerator, - optimizer, - train_dataloader, - val_dataloader, - loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], - callbacks: TrainingCallbacks = TrainingCallbacks() - ): - self.accelerator = accelerator - self.model = callbacks.on_model() - self.optimizer = optimizer - self.train_dataloader = train_dataloader - self.val_dataloader = val_dataloader - self.loss_fn = loss_fn - self.callbacks = callbacks - - # self.model_state = copy.deepcopy(model.state_dict()) - # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) - - def run( - self, - end_lr, - skip_start: int = 10, - skip_end: int = 5, - num_epochs: int = 100, - num_train_batches: int = math.inf, - num_val_batches: int = math.inf, - smooth_f: float = 0.05, - ): - best_loss = None - best_acc = None - - lrs = [] - losses = [] - accs = [] - - lr_scheduler = get_exponential_schedule( - self.optimizer, - end_lr, - num_epochs * min(num_train_batches, len(self.train_dataloader)) - ) - - steps = min(num_train_batches, len(self.train_dataloader)) - steps += min(num_val_batches, len(self.val_dataloader)) - steps *= num_epochs - - progress_bar = tqdm( - range(steps), - disable=not self.accelerator.is_local_main_process, - dynamic_ncols=True - ) - progress_bar.set_description("Epoch X / Y") - - self.callbacks.on_prepare() - - on_train = self.callbacks.on_train - on_before_optimize = self.callbacks.on_before_optimize - on_after_optimize = self.callbacks.on_after_optimize - on_eval = self.callbacks.on_eval - - for epoch in range(num_epochs): - progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") - - avg_loss = AverageMeter() - avg_acc = AverageMeter() - - self.model.train() - - with on_train(epoch): - for step, batch in enumerate(self.train_dataloader): - if step >= num_train_batches: - break - - with self.accelerator.accumulate(self.model): - loss, acc, bsz = self.loss_fn(step, batch) - - self.accelerator.backward(loss) - - on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) - - self.optimizer.step() - lr_scheduler.step() - self.optimizer.zero_grad(set_to_none=True) - - if self.accelerator.sync_gradients: - on_after_optimize(lr_scheduler.get_last_lr()[0]) - - progress_bar.update(1) - - self.model.eval() - - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(self.val_dataloader): - if step >= num_val_batches: - break - - loss, acc, bsz = self.loss_fn(step, batch, True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) - - progress_bar.update(1) - - loss = avg_loss.avg.item() - acc = avg_acc.avg.item() - - if epoch == 0: - best_loss = loss - best_acc = acc - else: - if smooth_f > 0: - loss = smooth_f * loss + (1 - smooth_f) * losses[-1] - acc = smooth_f * acc + (1 - smooth_f) * accs[-1] - if loss < best_loss: - best_loss = loss - if acc > best_acc: - best_acc = acc - - lr = lr_scheduler.get_last_lr()[0] - - lrs.append(lr) - losses.append(loss) - accs.append(acc) - - self.accelerator.log({ - "loss": loss, - "acc": acc, - "lr": lr, - }, step=epoch) - - progress_bar.set_postfix({ - "loss": loss, - "loss/best": best_loss, - "acc": acc, - "acc/best": best_acc, - "lr": lr, - }) - - # self.model.load_state_dict(self.model_state) - # self.optimizer.load_state_dict(self.optimizer_state) - - if skip_end == 0: - lrs = lrs[skip_start:] - losses = losses[skip_start:] - accs = accs[skip_start:] - else: - lrs = lrs[skip_start:-skip_end] - losses = losses[skip_start:-skip_end] - accs = accs[skip_start:-skip_end] - - fig, ax_loss = plt.subplots() - ax_acc = ax_loss.twinx() - - ax_loss.plot(lrs, losses, color='red') - ax_loss.set_xscale("log") - ax_loss.set_xlabel(f"Learning rate") - ax_loss.set_ylabel("Loss") - - ax_acc.plot(lrs, accs, color='blue') - ax_acc.set_xscale("log") - ax_acc.set_ylabel("Accuracy") - - print("LR suggestion: steepest gradient") - min_grad_idx = None - - try: - min_grad_idx = np.gradient(np.array(losses)).argmin() - except ValueError: - print( - "Failed to compute the gradients, there might not be enough points." - ) - - try: - max_val_idx = np.array(accs).argmax() - except ValueError: - print( - "Failed to compute the gradients, there might not be enough points." - ) - - if min_grad_idx is not None: - print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx])) - ax_loss.scatter( - lrs[min_grad_idx], - losses[min_grad_idx], - s=75, - marker="o", - color="red", - zorder=3, - label="steepest gradient", - ) - ax_loss.legend() - - if max_val_idx is not None: - print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx])) - ax_acc.scatter( - lrs[max_val_idx], - accs[max_val_idx], - s=75, - marker="o", - color="blue", - zorder=3, - label="maximum", - ) - ax_acc.legend() - - -def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1): - def lr_lambda(base_lr: float, current_epoch: int): - return (end_lr / base_lr) ** (current_epoch / num_epochs) - - lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups] - - return LambdaLR(optimizer, lr_lambdas, last_epoch) +def plot_metrics( + metrics: tuple[list[float], list[float], list[float]], + output_file: Path, + skip_start: int = 10, + skip_end: int = 5, +): + lrs, losses, accs = metrics + + if skip_end == 0: + lrs = lrs[skip_start:] + losses = losses[skip_start:] + accs = accs[skip_start:] + else: + lrs = lrs[skip_start:-skip_end] + losses = losses[skip_start:-skip_end] + accs = accs[skip_start:-skip_end] + + fig, ax_loss = plt.subplots() + ax_acc = ax_loss.twinx() + + ax_loss.plot(lrs, losses, color='red') + ax_loss.set_xscale("log") + ax_loss.set_xlabel(f"Learning rate") + ax_loss.set_ylabel("Loss") + + ax_acc.plot(lrs, accs, color='blue') + ax_acc.set_xscale("log") + ax_acc.set_ylabel("Accuracy") + + plt.savefig(output_file, dpi=300) + plt.close() diff --git a/training/optimization.py b/training/optimization.py index 6dee4bc..6c9a35d 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -87,6 +87,15 @@ def get_one_cycle_schedule( return LambdaLR(optimizer, lr_lambda, last_epoch) +def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): + def lr_lambda(base_lr: float, current_step: int): + return (end_lr / base_lr) ** (current_step / num_training_steps) + + lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups] + + return LambdaLR(optimizer, lr_lambdas, last_epoch) + + def get_scheduler( id: str, optimizer: torch.optim.Optimizer, @@ -97,6 +106,7 @@ def get_scheduler( annealing_func: Literal["cos", "half_cos", "linear"] = "cos", warmup_exp: int = 1, annealing_exp: int = 1, + end_lr: float = 1e3, cycles: int = 1, train_epochs: int = 100, warmup_epochs: int = 10, @@ -117,6 +127,15 @@ def get_scheduler( annealing_exp=annealing_exp, min_lr=min_lr, ) + elif id == "exponential_growth": + if cycles is None: + cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + + lr_scheduler = get_exponential_growing_schedule( + optimizer=optimizer, + end_lr=end_lr, + num_training_steps=num_training_steps, + ) elif id == "cosine_with_restarts": if cycles is None: cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 1277939..e88bf90 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -193,9 +193,7 @@ def dreambooth_prepare( unet: UNet2DConditionModel, *args ): - prep = [text_encoder, unet] + list(args) - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + return accelerator.prepare(text_encoder, unet, *args) dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a76f98..14bdafd 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -176,10 +176,9 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - prep = [text_encoder] + list(args) - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) + prepped = accelerator.prepare(text_encoder, *args) unet.to(accelerator.device, dtype=weight_dtype) - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + return (prepped[0], unet) + prepped[1:] textual_inversion_strategy = TrainingStrategy( diff --git a/training/util.py b/training/util.py index 237626f..c8524de 100644 --- a/training/util.py +++ b/training/util.py @@ -6,6 +6,8 @@ from contextlib import contextmanager import torch +from diffusers.training_utils import EMAModel as EMAModel_ + def save_args(basepath: Path, args, extra={}): info = {"args": vars(args)} @@ -30,149 +32,7 @@ class AverageMeter: self.avg = self.sum / self.count -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - update_after_step: int = 0, - inv_gamma: float = 1.0, - power: float = 2 / 3, - min_value: float = 0.0, - max_value: float = 0.9999, - ): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. - """ - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.collected_params = None - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - self.decay = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step: int): - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - self.decay = self.get_decay(self.optimization_step) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.mul_(self.decay) - s_param.add_(param.data, alpha=1 - self.decay) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. - This method is used by accelerate during checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "optimization_step": self.optimization_step, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Loads the ExponentialMovingAverage state. - This method is used by accelerate during checkpointing to save the ema state dict. - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.optimization_step = state_dict["optimization_step"] - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") - +class EMAModel(EMAModel_): @contextmanager def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): parameters = list(parameters) -- cgit v1.2.3-70-g09d2