From 5c115a212e40ff177c734351601f9babe29419ce Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 22:05:25 +0100 Subject: Added EMA to TI --- models/clip/embeddings.py | 6 +-- train_dreambooth.py | 2 +- train_ti.py | 59 ++++++++++++++++++++++++++- training/util.py | 100 +++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 157 insertions(+), 10 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index fb639f1..384c795 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): def save_embed(self, input_ids: list[int], filename: Path): save_file({"embed": self.get_embed(input_ids)}, filename) - def make_permanent(self): + def persist(self): self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] self.temp_token_ids = torch.tensor([], dtype=torch.long) @@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) - mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds = self.token_embedding(input_ids) + + mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) embeds[mask] = self.temp_token_embedding(input_ids)[mask] return embeds diff --git a/train_dreambooth.py b/train_dreambooth.py index 4d1e0a3..c355ea1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -638,7 +638,7 @@ def main(): if args.train_text_encoder: print(f"Training entire text encoder.") - embeddings.make_permanent() + embeddings.persist() text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) else: print(f"Training added text embeddings") diff --git a/train_ti.py b/train_ti.py index 98385dd..dc36e42 100644 --- a/train_ti.py +++ b/train_ti.py @@ -2,6 +2,7 @@ import argparse import math import datetime import logging +import copy from pathlib import Path from functools import partial @@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem from training.common import run_model from training.optimization import get_one_cycle_schedule from training.lr import LRFinder -from training.util import AverageMeter, CheckpointerBase, save_args +from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args from models.clip.embeddings import patch_managed_embeddings from models.clip.prompt import PromptProcessor from models.clip.tokenizer import MultiCLIPTokenizer @@ -266,6 +267,27 @@ def parse_args(): default=None, help="Minimum learning rate in the lr scheduler." ) + parser.add_argument( + "--use_ema", + action="store_true", + default=True, + help="Whether to use EMA model." 
+ ) + parser.add_argument( + "--ema_inv_gamma", + type=float, + default=1.0 + ) + parser.add_argument( + "--ema_power", + type=float, + default=6/7 + ) + parser.add_argument( + "--ema_max_decay", + type=float, + default=0.9999 + ) parser.add_argument( "--use_8bit_adam", action="store_true", @@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase): unet, tokenizer, text_encoder, + ema_embeddings, scheduler, placeholder_token, new_ids, @@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase): self.unet = unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.ema_embeddings = ema_embeddings self.scheduler = scheduler self.placeholder_token = placeholder_token self.new_ids = new_ids @@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) + if self.ema_embeddings is not None: + orig_weights = text_encoder.text_model.embeddings.temp_token_embedding + ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) + self.ema_embeddings.copy_to(ema_weights.parameters()) + text_encoder.text_model.embeddings.temp_token_embedding = ema_weights + for (token, ids) in zip(self.placeholder_token, self.new_ids): text_encoder.text_model.embeddings.save_embed( ids, checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") ) + if self.ema_embeddings is not None: + text_encoder.text_model.embeddings.temp_token_embedding = orig_weights + del text_encoder @torch.no_grad() def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + if self.ema_embeddings is not None: + orig_weights = text_encoder.text_model.embeddings.temp_token_embedding + ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) + self.ema_embeddings.copy_to(ema_weights.parameters()) + text_encoder.text_model.embeddings.temp_token_embedding = ema_weights + orig_dtype = text_encoder.dtype text_encoder.to(dtype=self.weight_dtype) @@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase): text_encoder.to(dtype=orig_dtype) + if self.ema_embeddings is not None: + text_encoder.text_model.embeddings.temp_token_embedding = orig_weights + del text_encoder del pipeline @@ -567,6 +610,7 @@ def main(): text_encoder.gradient_checkpointing_enable() embeddings = patch_managed_embeddings(text_encoder) + ema_embeddings = None if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) @@ -592,6 +636,14 @@ def main(): print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") + if args.use_ema: + ema_embeddings = EMAModel( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + ) + vae.requires_grad_(False) unet.requires_grad_(False) @@ -788,6 +840,7 @@ def main(): # Move vae and unet to device vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) + ema_embeddings.to(accelerator.device) # Keep vae and unet in eval mode as we don't train these vae.eval() @@ -883,6 +936,7 @@ def main(): unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, + ema_embeddings=ema_embeddings, scheduler=checkpoint_scheduler, placeholder_token=args.placeholder_token, new_ids=new_ids, @@ -935,6 +989,9 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + if args.use_ema: + 
ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + local_progress_bar.update(1) global_progress_bar.update(1) diff --git a/training/util.py b/training/util.py index 43a55e1..93b6248 100644 --- a/training/util.py +++ b/training/util.py @@ -1,5 +1,6 @@ from pathlib import Path import json +import copy from typing import Iterable import torch @@ -116,18 +117,58 @@ class CheckpointerBase: del generator +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ Exponential Moving Average of models weights """ - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + update_after_step=0, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + max_value=0.9999, + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. + """ parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] - self.decay = decay + self.collected_params = None + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + self.decay = 0.0 self.optimization_step = 0 + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + @torch.no_grad() def step(self, parameters): parameters = list(parameters) @@ -135,12 +176,12 @@ class EMAModel: self.optimization_step += 1 # Compute the decay factor for the exponential moving average. - value = (1 + self.optimization_step) / (10 + self.optimization_step) - one_minus_decay = 1 - min(self.decay, value) + self.decay = self.get_decay(self.optimization_step) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) + s_param.mul_(self.decay) + s_param.add_(param.data, alpha=1 - self.decay) else: s_param.copy_(param) @@ -169,3 +210,52 @@ class EMAModel: p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. + This method is used by accelerate during checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" 
- + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "optimization_step": self.optimization_step, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Loads the ExponentialMovingAverage state. + This method is used by accelerate during checkpointing to save the ema state dict. + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.optimization_step = state_dict["optimization_step"] + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") -- cgit v1.2.3-54-g00ecf
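
For reference, the reordered lines in ManagedCLIPTextEmbeddings.get_embed keep the behaviour unchanged: embeddings are first looked up in the permanent token_embedding, then the rows whose ids are still registered as temporary are overwritten from temp_token_embedding; persist() (formerly make_permanent()) bakes those rows into the permanent table. A minimal sketch of that overlay, using toy nn.Embedding tables in place of the managed CLIP embeddings (table sizes and ids below are illustrative):

import torch

token_embedding = torch.nn.Embedding(10, 4)       # permanent embeddings
temp_token_embedding = torch.nn.Embedding(10, 4)  # trainable rows for newly added tokens
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)

input_ids = torch.tensor([1, 7, 3, 8], dtype=torch.long)
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
embeds[mask] = temp_token_embedding(input_ids)[mask]  # temporary ids take the trained rows

# persist(): copy the trained rows into the permanent table
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]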
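
The EMA wiring in train_ti.py follows the usual three-step pattern: build an EMAModel over the temporary token embedding's parameters, call step() after every real optimizer step (when accelerator.sync_gradients is true), and copy the averaged weights into a throwaway clone only for checkpointing and sample generation, so the live weights keep training. A condensed, self-contained sketch of that flow, assuming training/util.py from this patch is importable and substituting a toy embedding and a dummy loss for the real text encoder and diffusion objective (copy_to comes from the pre-existing EMAModel helpers used by the Checkpointer above):

import copy
import torch
from training.util import EMAModel

temp_token_embedding = torch.nn.Embedding(4, 8)  # stand-in for embeddings.temp_token_embedding
ema_embeddings = EMAModel(
    temp_token_embedding.parameters(),
    inv_gamma=1.0,     # --ema_inv_gamma
    power=6 / 7,       # --ema_power
    max_value=0.9999,  # --ema_max_decay
)
optimizer = torch.optim.AdamW(temp_token_embedding.parameters(), lr=1e-2)

for _ in range(10):  # one iteration per synced optimizer step
    loss = temp_token_embedding(torch.tensor([0, 1])).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_embeddings.step(temp_token_embedding.parameters())  # update the shadow weights

# Checkpointing / sampling: swap the EMA weights into a deep copy and leave the
# live weights untouched for further training.
ema_weights = copy.deepcopy(temp_token_embedding)
ema_embeddings.copy_to(ema_weights.parameters())
# ... use ema_weights for save_embed() / sample generation ...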
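
The warmup schedule added in EMAModel.get_decay replaces the old min(decay, (1 + step) / (10 + step)) rule with @crowsonkb's formula, decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]. With the train_ti.py defaults (inv_gamma=1.0, power=6/7, max_decay=0.9999) the decay crosses 0.999 around 3.2K steps and reaches the 0.9999 cap around 46K steps. A standalone sketch of the same schedule (the helper name ema_decay_at is illustrative):

def ema_decay_at(optimization_step: int,
                 update_after_step: int = 0,
                 inv_gamma: float = 1.0,
                 power: float = 6 / 7,
                 min_value: float = 0.0,
                 max_value: float = 0.9999) -> float:
    # Same result as EMAModel.get_decay above.
    step = max(0, optimization_step - update_after_step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

for step in (10, 100, 1_000, 3_200, 50_000):
    print(step, round(ema_decay_at(step), 4))
# roughly 0.86, 0.98, 0.997, 0.999, 0.9999 (clamped by max_value)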