From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 21:53:07 +0100
Subject: WIP: Modularization ("free(): invalid pointer" my ass)

---
 training/util.py | 214 ++++++++++++++++++++++++++++++-------------------------
 1 file changed, 118 insertions(+), 96 deletions(-)

diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
 from pathlib import Path
 import json
 import copy
-import itertools
-from typing import Iterable, Optional
+from typing import Iterable, Union
 from contextlib import contextmanager
 
 import torch
-from PIL import Image
+
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+
+
+class TrainingStrategy():
+    @property
+    def main_model(self) -> torch.nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        yield
+
+    @contextmanager
+    def on_eval(self):
+        yield
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log(self):
+        return {}
 
 
 def save_args(basepath: Path, args, extra={}):
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_tokens
+    ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+
+    return placeholder_token_ids, initializer_token_ids
 
 
 class AverageMeter:
@@ -38,93 +147,6 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-class CheckpointerBase:
-    def __init__(
-        self,
-        train_dataloader,
-        val_dataloader,
-        output_dir: Path,
-        sample_steps: int = 20,
-        sample_guidance_scale: float = 7.5,
-        sample_image_size: int = 768,
-        sample_batches: int = 1,
-        sample_batch_size: int = 1,
-        seed: Optional[int] = None
-    ):
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.sample_steps = sample_steps
-        self.sample_guidance_scale = sample_guidance_scale
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-        self.seed = seed if seed is not None else torch.random.seed()
-
-    @torch.no_grad()
-    def checkpoint(self, step: int, postfix: str):
-        pass
-
-    @torch.inference_mode()
-    def save_samples(self, pipeline, step: int):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-
-        grid_cols = min(self.sample_batch_size, 4)
-        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
-
-        for pool, data, gen in [
-            ("stable", self.val_dataloader, generator),
-            ("val", self.val_dataloader, None),
-            ("train", self.train_dataloader, None)
-        ]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-            prompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["prompt_ids"]
-            ]
-            nprompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["nprompt_ids"]
-            ]
-
-            for i in range(self.sample_batches):
-                start = i * self.sample_batch_size
-                end = (i + 1) * self.sample_batch_size
-                prompt = prompt_ids[start:end]
-                nprompt = nprompt_ids[start:end]
-
-                samples = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    generator=gen,
-                    guidance_scale=self.sample_guidance_scale,
-                    num_inference_steps=self.sample_steps,
-                    output_type='pil'
-                ).images
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, grid_rows, grid_cols)
-            image_grid.save(file_path, quality=85)
-
-            del all_samples
-            del image_grid
-
-        del generator
-
-
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
--
cgit v1.2.3-54-g00ecf
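
For context beyond the patch itself: a minimal sketch of how get_models and add_placeholder_tokens might be wired together in a training script. The model identifier and token names below are placeholder assumptions, not values from this patch; num_vectors may be a single int applied to every placeholder token or a per-token list, per the Union[list[int], int] annotation.

    # Hypothetical wiring; model path and token names are illustrative only.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "runwayml/stable-diffusion-v1-5"
    )

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<concept>"],
        initializer_tokens=["artwork"],
        num_vectors=4,  # each placeholder expands to 4 embedding vectors
    )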
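
With the models loaded, generate_class_images is meant to run once before training to fill in missing prior-preservation images. A sketch, assuming accelerator is an accelerate.Accelerator and data_train is a list of items carrying class_image_path, cprompt, and nprompt attributes (as the function body expects); the sample sizes are arbitrary example values.

    generate_class_images(
        accelerator,
        text_encoder,
        vae,
        unet,
        tokenizer,
        sample_scheduler,
        data_train,
        sample_batch_size=4,
        sample_image_size=768,
        sample_steps=20,
    )

Only items whose class_image_path does not yet exist are generated, so repeated runs are cheap.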
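
Finally, since TrainingStrategy is only a hook contract in this commit, a sketch of what a concrete subclass might look like; the class name, the train/eval mode toggling, and the log payload are assumptions about intended use, not code from this patch.

    from contextlib import contextmanager

    import torch


    class TextualInversionStrategy(TrainingStrategy):
        # Hypothetical strategy that trains only the text encoder.
        def __init__(self, text_encoder: torch.nn.Module):
            self.text_encoder = text_encoder

        @property
        def main_model(self) -> torch.nn.Module:
            # The module whose parameters the outer training loop optimizes.
            return self.text_encoder

        @contextmanager
        def on_train(self, epoch: int):
            # Entered once per epoch; keep the model in training mode inside.
            self.text_encoder.train()
            yield

        @contextmanager
        def on_eval(self):
            # Wraps validation/sampling; restore training mode afterwards.
            self.text_encoder.eval()
            try:
                yield
            finally:
                self.text_encoder.train()

        def on_log(self):
            # Extra key/value pairs merged into the loop's log output.
            return {"strategy": "textual_inversion"}

A training loop would call on_before_optimize(epoch) and on_after_optimize(lr) around each optimizer step; both are left as the base-class no-ops in this sketch.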