from pathlib import Path
import json
import copy
from typing import Iterable, Union
from contextlib import contextmanager

import torch
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings


def save_args(basepath: Path, args, extra=None):
    """Write the run's arguments to ``basepath / "args.json"``.

    The file contains ``{"args": {...}}``, where the inner dict is
    ``vars(args)`` merged with ``extra`` (entries from ``extra`` win on key
    collisions).

    Args:
        basepath: Directory in which ``args.json`` is written.
        args: Any object supported by ``vars()``, e.g. an ``argparse.Namespace``.
        extra: Optional mapping (or iterable of key/value pairs) with additional
            entries to record. Defaults to nothing.
    """
    # vars() returns the live __dict__ of ``args``; copy it so that merging
    # ``extra`` does not silently add attributes to the caller's ``args``.
    merged = dict(vars(args))
    if extra is not None:
        merged.update(extra)
    info = {"args": merged}

    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


class AverageMeter:
    """Track a running, count-weighted average of values."""

    def __init__(self, name=None):
        # Optional label; the class itself never reads it.
        self.name = name
        self.reset()

    def reset(self):
        """Zero the accumulated sum, count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute ``avg``.

        Raises ZeroDivisionError if the accumulated count is still 0 afterwards
        (e.g. ``update(x, n=0)`` on a freshly reset meter).
        """
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        update_after_step: int = 0,
        inv_gamma: float = 1.0,
        power: float = 2 / 3,
        min_value: float = 0.0,
        max_value: float = 0.9999,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).

        Args:
            parameters (Iterable[torch.nn.Parameter]): Parameters to track. A detached
                clone of each is stored as the initial shadow (averaged) copy.
            update_after_step (int): The decay stays 0 (shadow copies simply follow the
                live weights) until the optimization step exceeds
                ``update_after_step + 1``. Default: 0.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        # This class only ever sets this to None; it is still round-tripped
        # through state_dict()/load_state_dict().
        self.collected_params = None

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        # Most recent decay used by step(), and the number of step() calls so far.
        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step: int):
        """
        Compute the decay factor for the exponential moving average.

        With ``step = optimization_step - update_after_step - 1``, returns 0.0 while
        ``step <= 0``; otherwise returns ``1 - (1 + step / inv_gamma) ** -power``
        bounded via ``max(min_value, min(value, max_value))``.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma) ** -self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, parameters):
        """Advance one optimization step and blend ``parameters`` into the shadow copies.

        Parameters with ``requires_grad`` are updated as
        ``shadow = decay * shadow + (1 - decay) * param``; all others are copied
        verbatim. ``parameters`` is paired positionally (via ``zip``) with the
        parameters given to ``__init__``, so it must use the same order.
        """
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        self.decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.mul_(self.decay)
                s_param.add_(param.data, alpha=1 - self.decay)
            else:
                s_param.copy_(param)

        # NOTE(review): this runs on every step; confirm the per-step cache flush
        # is worth its cost.
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                overwritten in place with the stored moving averages. They are
                paired positionally with the parameters given to ``__init__``.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to
                floating-point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict.
        This method is used by accelerate during checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "optimization_step": self.optimization_step,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state.
        This method is used by accelerate during checkpointing to load the ema state dict.

        Every field is read and validated before any attribute is assigned, so a
        failed load leaves this object unchanged.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If a field has an invalid value or type.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        decay = state_dict["decay"]
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        optimization_step = state_dict["optimization_step"]
        if not isinstance(optimization_step, int):
            raise ValueError("Invalid optimization_step")

        shadow_params = state_dict["shadow_params"]
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

        collected_params = state_dict["collected_params"]
        if collected_params is not None:
            if not isinstance(collected_params, list):
                raise ValueError("collected_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in collected_params):
                raise ValueError("collected_params must all be Tensors")
            if len(collected_params) != len(shadow_params):
                raise ValueError("collected_params and shadow_params must have the same length")

        # All checks passed: commit the new state.
        self.decay = decay
        self.optimization_step = optimization_step
        self.shadow_params = shadow_params
        self.collected_params = collected_params

    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        """Temporarily load the averaged weights into ``parameters``.

        On entry, the current values of ``parameters`` are backed up and the
        shadow (averaged) values are copied in. On exit, including when the
        ``with`` body or the copy itself raises, the backed-up values are
        restored.
        """
        parameters = list(parameters)
        # detach() so the backup clones are not recorded in the autograd graph.
        original_params = [p.detach().clone() for p in parameters]

        try:
            # Inside the try so a partial copy is also rolled back.
            self.copy_to(parameters)
            yield
        finally:
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)
            del original_params