from pathlib import Path
import json
import copy
from typing import Iterable, Union
from contextlib import contextmanager

import torch

from transformers import CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings


class TrainingStrategy():
    @property
    def main_model(self) -> torch.nn.Module:
        ...

    @contextmanager
    def on_train(self, epoch: int):
        yield

    @contextmanager
    def on_eval(self):
        yield

    def on_before_optimize(self, epoch: int):
        ...

    def on_after_optimize(self, lr: float):
        ...

    def on_log(self):
        return {}


def save_args(basepath: Path, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    # Only generate class images that don't exist on disk yet.
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_models(pretrained_model_name_or_path: str):
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Union[list[int], int]
):
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]
    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    # Grow the embedding matrix to cover the new tokens, then initialize each
    # placeholder embedding from its initializer token.
    embeddings.resize(len(tokenizer))

    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(placeholder_token_id, initializer_token_id)

    return placeholder_token_ids, initializer_token_ids
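# Illustrative usage (a sketch, not part of the original module): how `get_models`
# and `add_placeholder_tokens` are typically wired together in a training script.
# The checkpoint path and token names below are placeholders.
#
#   tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
#       get_models("path/to/stable-diffusion-checkpoint")
#   placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
#       tokenizer,
#       embeddings,
#       placeholder_tokens=["<new-concept>"],
#       initializer_tokens=["person"],
#       num_vectors=1,
#   )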
class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        update_after_step: int = 0,
        inv_gamma: float = 1.0,
        power: float = 2 / 3,
        min_value: float = 0.0,
        max_value: float = 0.9999,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
            good values for models you plan to train for a million or more steps (reaches
            decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for
            models you plan to train for less (reaches decay factor 0.999 at 10K steps,
            0.9999 at 215.4K steps).

        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.collected_params = None

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step: int):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma) ** -self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, parameters):
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        self.decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.mul_(self.decay)
                s_param.add_(param.data, alpha=1 - self.decay)
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy the current averaged parameters into the given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated
                with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like the `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
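    # Note (illustrative, not part of the original module): `state_dict` and
    # `load_state_dict` below follow the usual torch conventions so that the EMA
    # state can be checkpointed together with the rest of the training state,
    # e.g. via `accelerator.register_for_checkpointing(ema_model)` in the
    # training script (the exact wiring lives outside this module).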
    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict.
        This method is used by accelerate during checkpointing to save the EMA state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "optimization_step": self.optimization_step,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state.
        This method is used by accelerate during checkpointing to load the EMA state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict["decay"]
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.optimization_step = state_dict["optimization_step"]
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.shadow_params = state_dict["shadow_params"]
        if not isinstance(self.shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
            raise ValueError("shadow_params must all be Tensors")

        self.collected_params = state_dict["collected_params"]
        if self.collected_params is not None:
            if not isinstance(self.collected_params, list):
                raise ValueError("collected_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
                raise ValueError("collected_params must all be Tensors")
            if len(self.collected_params) != len(self.shadow_params):
                raise ValueError("collected_params and shadow_params must have the same length")

    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        # Temporarily swap in the averaged weights, restoring the originals on exit.
        parameters = list(parameters)
        original_params = [p.clone() for p in parameters]
        self.copy_to(parameters)
        try:
            yield
        finally:
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)
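# Illustrative usage (a sketch, not part of the original module): how EMAModel is
# typically driven from a training loop. `unet`, `optimizer`, `dataloader` and
# `run_validation` are placeholders for objects defined in the training script.
#
#   ema_unet = EMAModel(unet.parameters(), inv_gamma=1.0, power=2 / 3)
#   for batch in dataloader:
#       ...                                  # forward/backward as usual
#       optimizer.step()
#       optimizer.zero_grad()
#       ema_unet.step(unet.parameters())     # update the shadow weights
#
#   with ema_unet.apply_temporary(unet.parameters()):
#       run_validation(unet)                 # evaluate with averaged weights, then restore
#
# The decay follows the warmup schedule implemented in `get_decay`:
#   decay(step) = 1 - (1 + step / inv_gamma) ** (-power), clamped to [min_value, max_value].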