import itertools
import json
from pathlib import Path
from typing import Iterable, Optional

import torch
from PIL import Image


def save_args(basepath: Path, args, extra: Optional[dict] = None):
    """Dump the run arguments to ``<basepath>/args.json``.

    Args:
        basepath: Directory in which ``args.json`` is written.
        args: Object whose attributes are read with ``vars()`` (for example
            an ``argparse.Namespace``).
        extra: Optional additional key/value pairs merged over the
            arguments in the written file; on key collision ``extra`` wins.

    The file content is ``{"args": {...}}`` with 4-space indentation.
    """
    # vars() hands back the object's live __dict__, so merge into a copy;
    # updating it in place would add the extra keys to `args` itself.
    merged = dict(vars(args))
    if extra:
        merged.update(extra)

    with open(basepath.joinpath("args.json"), "w", encoding="utf-8") as f:
        json.dump({"args": merged}, f, indent=4)


def make_grid(images, rows, cols):
    """Tile equally sized PIL images into one RGB image, row by row.

    Args:
        images: Non-empty sequence of PIL images; the size of the first
            image is used as the cell size for every tile.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new ``RGB`` image of size ``(cols * w, rows * h)``. Image ``i`` is
        pasted at column ``i % cols`` and row ``i // cols``.
    """
    cell_w, cell_h = images[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        grid.paste(image, box=(col * cell_w, row * cell_h))
    return grid


class AverageMeter:
    """Keep a running weighted sum, count and average of a value."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        """Set ``sum``, ``count`` and ``avg`` back to zero."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the average."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CheckpointerBase:
    """Store sampling settings and render sample grids from a pipeline."""

    def __init__(
        self,
        datamodule,
        output_dir: Path,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        """Store the sampling configuration.

        Args:
            datamodule: Object providing ``train_dataloader()`` and
                ``val_dataloader()``.
            output_dir: Directory under which ``samples/`` is created.
            sample_image_size: Height and width passed to the pipeline.
            sample_batches: Number of pipeline calls per sample pool.
            sample_batch_size: Number of prompts per pipeline call.
            seed: Seed for the "stable" pool generator. Any falsy value
                (``None`` or ``0``) is replaced by ``torch.random.seed()``.
        """
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        # NOTE: `or` also treats an explicit seed of 0 as "not provided".
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.inference_mode()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Generate sample images and save one JPEG grid per pool.

        Three pools are rendered: ``stable`` (validation prompts with a
        generator seeded from ``self.seed``), ``val`` (validation prompts,
        no generator) and ``train`` (training prompts, no generator). Each
        grid is written to ``<output_dir>/samples/<pool>/step_<step>.jpg``.

        Args:
            pipeline: Callable pipeline exposing ``.device`` and returning
                an object with an ``.images`` list when called.
            step: Value used in the output file name.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            eta: Forwarded to the pipeline.
        """
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()
        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)

        total_samples = self.sample_batches * self.sample_batch_size
        grid_cols = min(self.sample_batch_size, 4)
        grid_rows = total_samples // grid_cols

        pools = [
            ("stable", val_data, generator),
            ("val", val_data, None),
            ("train", train_data, None),
        ]
        for pool, data, gen in pools:
            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
            file_path.parent.mkdir(parents=True, exist_ok=True)

            # Take only the leading batches needed to cover total_samples
            # (batch j is needed iff j * batch_size < total_samples) and
            # stop there instead of iterating the whole dataloader.
            needed_batches = -(-total_samples // data.batch_size)
            batches = list(itertools.islice(data, needed_batches))

            prompts = [
                prompt
                for batch in batches
                for prompt in batch["prompts"]
            ]
            nprompts = [
                prompt
                for batch in batches
                for prompt in batch["nprompts"]
            ]

            all_samples = []
            for i in range(self.sample_batches):
                chunk = slice(i * self.sample_batch_size, (i + 1) * self.sample_batch_size)
                all_samples.extend(
                    pipeline(
                        prompt=prompts[chunk],
                        negative_prompt=nprompts[chunk],
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        generator=gen,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    ).images
                )

            image_grid = make_grid(all_samples, grid_rows, grid_cols)
            image_grid.save(file_path, quality=85)

            # Drop the images before the next pool starts generating.
            del all_samples
            del image_grid


class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        """Snapshot ``parameters`` as the initial averaged ("shadow") values.

        Args:
            parameters: Parameters to track; each is cloned and detached.
            decay: Upper bound for the decay factor used in ``step``.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.decay = decay
        self.optimization_step = 0

    @torch.no_grad()
    def step(self, parameters):
        """Update the shadow parameters from the current ``parameters``.

        The decay used is ``min(self.decay, (1 + t) / (10 + t))`` where
        ``t`` is the number of calls so far (including this one).
        Parameters with ``requires_grad`` set are blended into their shadow
        copy; all others are copied over unchanged. Parameters are matched
        to shadow copies by position.
        """
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        value = (1 + self.optimization_step) / (10 + self.optimization_step)
        one_minus_decay = 1 - min(self.decay, value)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. Must be provided
                (``None`` is not accepted) and is matched to the stored
                averages by position.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; applied only
                to floating-point buffers, other buffers keep their dtype.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]