from pathlib import Path
import json
import copy
from typing import Iterable
from contextlib import contextmanager

import torch
from PIL import Image


def save_args(basepath: Path, args, extra=None):
    """Write the run's arguments to ``<basepath>/args.json``.

    The file contains a single object ``{"args": {...}}``.

    Args:
        basepath: Directory in which ``args.json`` is created.
        args: Object whose attributes are serialised via ``vars()``
            (e.g. an ``argparse.Namespace``).
        extra: Optional mapping merged into the saved ``"args"`` dict; its
            entries override attributes of the same name.
    """
    # vars() returns the object's own __dict__; copy it so that merging
    # ``extra`` does not add attributes to the caller's ``args``.
    saved_args = dict(vars(args))
    if extra:
        saved_args.update(extra)
    info = {"args": saved_args}

    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def make_grid(images, rows, cols):
    """Tile PIL images row-major into a single RGB image.

    Args:
        images: Non-empty sequence of PIL images. The size of the first image
            is used as the cell size for every image.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``. Images beyond
        ``rows * cols`` fall outside the canvas, so callers must size the
        grid to fit all images.
    """
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(image, box=(col * w, row * h))
    return grid


class AverageMeter:
    """Track a running, count-weighted average of a scalar metric."""

    def __init__(self, name=None):
        self.name = name  # optional label; not used by the meter itself
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` weighted by ``n`` and refresh ``avg``.

        Args:
            val: Value to accumulate (e.g. a mean over ``n`` items).
            n: Weight/count of ``val``. Default: 1.
        """
        self.sum += val * n
        self.count += n
        # Guard against an initial update with n=0 (nothing counted yet).
        if self.count:
            self.avg = self.sum / self.count


class CheckpointerBase:
    """Base helper that renders and saves sample image grids during training."""

    def __init__(
        self,
        datamodule,
        output_dir: Path,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        # 0 is a valid seed; only draw a fresh non-deterministic seed when
        # none was supplied. Note torch.random.seed() also reseeds torch's
        # default generator.
        self.seed = seed if seed is not None else torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.inference_mode()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Generate sample grids and save them as JPEGs under ``<output_dir>/samples``.

        Three pools are rendered, each to ``samples/<pool>/step_<step>.jpg``:
        ``"stable"`` uses validation prompts with a generator seeded from
        ``self.seed`` (so images are comparable across steps); ``"val"`` and
        ``"train"`` use validation / training prompts with ``generator=None``.

        Args:
            pipeline: Callable exposing ``.device`` whose result has an
                ``.images`` list of PIL images (presumably a diffusers text-to-image
                pipeline — confirm against callers).
            step: Training step; used only in the output file name.
            num_inference_steps: Forwarded to ``pipeline``.
            guidance_scale: Forwarded to ``pipeline``. Default: 7.5.
            eta: Forwarded to ``pipeline``. Default: 0.0.
        """
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)

        num_samples = self.sample_batches * self.sample_batch_size
        grid_cols = min(self.sample_batch_size, 4)
        # Ceiling division: when num_samples is not a multiple of grid_cols the
        # last partial row must still exist, otherwise those images are lost.
        grid_rows = (num_samples + grid_cols - 1) // grid_cols

        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
            all_samples = []
            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
            file_path.parent.mkdir(parents=True, exist_ok=True)

            # Only the leading batches are needed; stop iterating the dataloader
            # as soon as they are collected instead of loading the whole dataset.
            batches = []
            for j, batch in enumerate(data):
                if j * data.batch_size >= num_samples:
                    break
                batches.append(batch)

            prompts = [prompt for batch in batches for prompt in batch["prompts"]]
            nprompts = [prompt for batch in batches for prompt in batch["nprompts"]]

            for i in range(self.sample_batches):
                start = i * self.sample_batch_size
                end = start + self.sample_batch_size
                samples = pipeline(
                    prompt=prompts[start:end],
                    negative_prompt=nprompts[start:end],
                    height=self.sample_image_size,
                    width=self.sample_image_size,
                    generator=gen,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type='pil'
                ).images

                all_samples += samples
                del samples

            image_grid = make_grid(all_samples, grid_rows, grid_cols)
            image_grid.save(file_path, quality=85)

            del all_samples
            del image_grid

        del generator


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        Args:
            parameters: Parameters to track; they are cloned and detached to form
                the initial shadow (averaged) copy.
            update_after_step (int): Number of extra optimization steps during which
                the decay stays 0 (shadow weights just copy the model). Default: 0.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.collected_params = None

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 until ``optimization_step`` exceeds ``update_after_step + 1``,
        then the warmup curve ``1 - (1 + step / inv_gamma) ** -power`` clamped
        to ``[min_value, max_value]``.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma) ** -self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, parameters):
        """Advance one optimization step and fold ``parameters`` into the shadow copy.

        Trainable parameters are blended as ``shadow = decay * shadow + (1 - decay) * param``;
        parameters with ``requires_grad=False`` are copied verbatim.

        Args:
            parameters: Parameters in the same order as those passed to ``__init__``.
        """
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        self.decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.mul_(self.decay)
                s_param.add_(param.data, alpha=1 - self.decay)
            else:
                s_param.copy_(param)

        # Release cached allocator blocks after the update (deliberate memory
        # management; cost is paid every step).
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages, in the same order as
                those passed to ``__init__``. ``None`` is not accepted.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; applied only to
                floating-point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict.
        This method is used by accelerate during checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "optimization_step": self.optimization_step,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state.
        This method is used by accelerate during checkpointing to load the ema state dict.

        The whole state is validated before anything is assigned, so a rejected
        state dict leaves this object unchanged.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        Raises:
            ValueError: If an entry has an invalid type or value.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        decay = state_dict["decay"]
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        optimization_step = state_dict["optimization_step"]
        if not isinstance(optimization_step, int):
            raise ValueError("Invalid optimization_step")

        shadow_params = state_dict["shadow_params"]
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

        collected_params = state_dict["collected_params"]
        if collected_params is not None:
            if not isinstance(collected_params, list):
                raise ValueError("collected_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in collected_params):
                raise ValueError("collected_params must all be Tensors")
            if len(collected_params) != len(shadow_params):
                raise ValueError("collected_params and shadow_params must have the same length")

        self.decay = decay
        self.optimization_step = optimization_step
        self.shadow_params = shadow_params
        self.collected_params = collected_params

    @contextmanager
    def apply_temporary(self, parameters):
        """Temporarily load the averaged weights into ``parameters``.

        Inside the ``with`` block ``parameters`` hold the shadow values; on exit
        (including on exceptions) their previous values are restored.

        Args:
            parameters: Parameters in the same order as those passed to ``__init__``.
        """
        # Snapshot outside ``try``: if this setup fails there is nothing to
        # restore, and ``finally`` must not reference unbound names.
        parameters = list(parameters)
        original_params = [p.detach().clone() for p in parameters]

        try:
            self.copy_to(parameters)
            yield
        finally:
            for original, param in zip(original_params, parameters):
                param.data.copy_(original.data)