From 68540b27849564994d921968a36faa9b997e626d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Dec 2022 09:17:25 +0100 Subject: Moved common training code into separate module --- training/util.py | 131 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 training/util.py (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py new file mode 100644 index 0000000..e8d22ae --- /dev/null +++ b/training/util.py @@ -0,0 +1,131 @@ +from pathlib import Path +import json + +import torch +from PIL import Image + + +def freeze_params(params): + for param in params: + param.requires_grad = False + + +def save_args(basepath: Path, args, extra={}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(basepath.joinpath("args.json"), "w") as f: + json.dump(info, f, indent=4) + + +def make_grid(images, rows, cols): + w, h = images[0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols*w, i//cols*h)) + return grid + + +class AverageMeter: + def __init__(self, name=None): + self.name = name + self.reset() + + def reset(self): + self.sum = self.count = self.avg = 0 + + def update(self, val, n=1): + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class CheckpointerBase: + def __init__( + self, + datamodule, + output_dir: Path, + instance_identifier, + placeholder_token, + placeholder_token_id, + sample_image_size, + sample_batches, + sample_batch_size, + seed + ): + self.datamodule = datamodule + self.output_dir = output_dir + self.instance_identifier = instance_identifier + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id + self.sample_image_size = sample_image_size + self.seed = seed or torch.random.seed() + self.sample_batches = sample_batches + self.sample_batch_size = sample_batch_size + + @torch.no_grad() + def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + samples_path = Path(self.output_dir).joinpath("samples") + + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() + + generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) + stable_latents = torch.randn( + (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), + device=pipeline.device, + generator=generator, + ) + + with torch.autocast("cuda"), torch.inference_mode(): + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.jpg") + file_path.parent.mkdir(parents=True, exist_ok=True) + + data_enum = enumerate(data) + + batches = [ + batch + for j, batch in data_enum + if j * data.batch_size < self.sample_batch_size * self.sample_batches + ] + prompts = [ + prompt.format(identifier=self.instance_identifier) + for batch in batches + for prompt in batch["prompts"] + ] + nprompts = [ + prompt + for batch in batches + for prompt in batch["nprompts"] + ] + + for i in range(self.sample_batches): + prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=self.sample_image_size, + width=self.sample_image_size, + image=latents[:len(prompt)] if latents is not None else None, + generator=generator if latents is not None else None, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + ).images + + all_samples += samples + + del samples + + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path, quality=85) + + del all_samples + del image_grid + + del generator + del stable_latents -- cgit v1.2.3-54-g00ecf