from pathlib import Path
import itertools
import json

import torch
from PIL import Image


def freeze_params(params):
    """Disable gradient tracking for every tensor in ``params``.

    Args:
        params: iterable of tensors/parameters (e.g. ``module.parameters()``).
    """
    for param in params:
        param.requires_grad = False


def save_args(basepath: Path, args, extra=None):
    """Write ``args`` (plus optional ``extra`` entries) to ``basepath/args.json``.

    The written document has the shape ``{"args": {...}}``; keys in ``extra``
    override same-named attributes of ``args``. All values must be
    JSON-serializable.

    Args:
        basepath: existing directory that receives ``args.json``.
        args: object exposing ``__dict__`` (e.g. ``argparse.Namespace``).
        extra: optional mapping merged on top of ``vars(args)``.
    """
    # vars() returns the object's live __dict__; copy it so merging ``extra``
    # does not silently add attributes to the caller's ``args``.
    merged = dict(vars(args))
    if extra:
        merged.update(extra)
    info = {"args": merged}
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def make_grid(images, rows, cols):
    """Paste ``images`` row-major into a ``rows`` x ``cols`` RGB grid.

    Every cell takes the size of ``images[0]``. Cells without an image stay
    black (the ``Image.new`` default fill).

    Args:
        images: non-empty sequence of PIL images.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new ``PIL.Image.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("make_grid requires at least one image")
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        # Column index is i % cols, row index is i // cols.
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class AverageMeter:
    """Running weighted mean of a series of values."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh ``avg``."""
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. a first update with n=0).
        self.avg = self.sum / self.count if self.count else 0


class CheckpointerBase:
    """Base helper that renders sample image grids during training."""

    def __init__(
        self,
        datamodule,
        output_dir: Path,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        # ``seed or ...`` would throw away a legitimate seed of 0.
        self.seed = seed if seed is not None else torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    def _collect_prompts(self, data):
        """Return ``(prompts, nprompts)`` gathered from the leading batches of ``data``.

        Only the batches needed to cover ``sample_batches * sample_batch_size``
        prompts are drawn, so the rest of the dataloader is never loaded.
        Assumes every batch is a dict with ``"prompts"``/``"nprompts"`` lists
        of ``data.batch_size`` entries.
        """
        needed = self.sample_batch_size * self.sample_batches
        # Equivalent to keeping batches j with j * batch_size < needed.
        num_batches = (needed + data.batch_size - 1) // data.batch_size
        batches = list(itertools.islice(data, num_batches))
        prompts = [prompt for batch in batches for prompt in batch["prompts"]]
        nprompts = [prompt for batch in batches for prompt in batch["nprompts"]]
        return prompts, nprompts

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Render sample grids and save them under ``output_dir/samples``.

        Three pools are rendered, each to ``samples/<pool>/step_<step>.jpg``:

        * ``"stable"``: validation prompts with latents drawn from a
          generator seeded by ``self.seed``.
        * ``"val"``: validation prompts with pipeline-chosen noise.
        * ``"train"``: training prompts with pipeline-chosen noise.

        The pipeline is called ``sample_batches`` times per pool, each time
        with ``sample_batch_size`` prompts. Sampling stops early if the
        dataloader yields fewer prompts than that. A pool that produces no
        images writes no file.
        """
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
        stable_latents = torch.randn(
            (self.sample_batch_size, pipeline.unet.in_channels,
             self.sample_image_size // 8, self.sample_image_size // 8),
            device=pipeline.device,
            generator=generator,
        )

        pools = [
            ("stable", val_data, stable_latents),
            ("val", val_data, None),
            ("train", train_data, None),
        ]

        with torch.autocast("cuda"), torch.inference_mode():
            for pool, data, latents in pools:
                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                file_path.parent.mkdir(parents=True, exist_ok=True)

                prompts, nprompts = self._collect_prompts(data)

                all_samples = []
                for i in range(self.sample_batches):
                    start = i * self.sample_batch_size
                    stop = start + self.sample_batch_size
                    prompt = prompts[start:stop]
                    nprompt = nprompts[start:stop]
                    if not prompt:
                        # Dataset holds fewer prompts than requested; don't
                        # call the pipeline with an empty batch.
                        break

                    samples = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        # NOTE(review): the custom pipeline presumably treats
                        # ``image`` as initial latents here - confirm.
                        image=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    ).images

                    all_samples += samples
                    del samples

                if all_samples:
                    image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
                    image_grid.save(file_path, quality=85)
                    del image_grid

                del all_samples