From 34648b763fa60e3161a5b5f1243ed1b5c3b0188e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 10:12:04 +0100
Subject: Added functional TI strategy

---
 training/functional.py  | 118 ++++++++++++++++++++++++++++++++++
 training/strategy/ti.py | 164 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 282 insertions(+)
 create mode 100644 training/strategy/ti.py

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 1f2ca6d..e54c9c8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,8 @@ import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional
 from functools import partial
+from pathlib import Path
+import itertools
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +28,14 @@ def const(result=None):
     return fn
 
 
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
 def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -40,6 +50,107 @@ def get_models(pretrained_model_name_or_path: str):
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
+def save_samples(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    step: int,
+    batch_size: int = 1,
+    num_batches: int = 1,
+    num_steps: int = 20,
+    guidance_scale: float = 7.5,
+    image_size: Optional[int] = None,
+):
+    print(f"Saving samples for step {step}...")
+
+    samples_path = output_dir.joinpath("samples")
+
+    grid_cols = min(batch_size, 4)
+    grid_rows = (num_batches * batch_size) // grid_cols
+
+    unet = accelerator.unwrap_model(unet)
+    text_encoder = accelerator.unwrap_model(text_encoder)
+
+    orig_unet_dtype = unet.dtype
+    orig_text_encoder_dtype = text_encoder.dtype
+
+    unet.to(dtype=dtype)
+    text_encoder.to(dtype=dtype)
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=sample_scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
+
+    for pool, data, gen in [
+        ("stable", val_dataloader, generator),
+        ("val", val_dataloader, None),
+        ("train", train_dataloader, None)
+    ]:
+        all_samples = []
+        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
+        prompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["prompt_ids"]
+        ]
+        nprompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["nprompt_ids"]
+        ]
+
+        for i in range(num_batches):
+            start = i * batch_size
+            end = (i + 1) * batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
+
+            samples = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=image_size,
+                width=image_size,
+                generator=gen,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type='pil'
+            ).images
+
+            all_samples += samples
+
+        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid.save(file_path, quality=85)
+
+    unet.to(dtype=orig_unet_dtype)
+    text_encoder.to(dtype=orig_text_encoder_dtype)
+
+    del unet
+    del text_encoder
+    del generator
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def generate_class_images(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
@@ -109,6 +220,10 @@ def get_models(pretrained_model_name_or_path: str):
 
     embeddings = patch_managed_embeddings(text_encoder)
 
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
@@ -427,6 +542,9 @@ def train(
         seed,
     )
 
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
new file mode 100644
index 0000000..83dc566
--- /dev/null
+++ b/training/strategy/ti.py
@@ -0,0 +1,164 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from slugify import slugify
+
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import save_samples
+
+
+def textual_inversion_strategy(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    learning_rate: float,
+    gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=dtype,
+        output_dir=output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if use_ema:
+            ema_embeddings.to(accelerator.device)
+
+        if gradient_checkpointing:
+            unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if use_emb_decay:
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+            )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+
+        checkpoints_path = output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = accelerator.unwrap_model(text_encoder)
+
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+    @torch.no_grad()
+    def on_sample(step):
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            save_samples_(step=step)
+
+    return {
+        "on_prepare": on_prepare,
+        "on_train": on_train,
+        "on_eval": on_eval,
+        "on_after_optimize": on_after_optimize,
+        "on_log": on_log,
+        "on_checkpoint": on_checkpoint,
+        "on_sample": on_sample,
+    }
--
cgit v1.2.3-54-g00ecf
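
Below is a minimal sketch of how the callback dictionary returned by textual_inversion_strategy might be exercised. It is not taken from the patch: the dataloaders, the placeholder tokens/ids, and every concrete value (model path, dtype, output directory, seed, learning rate, step numbers) are placeholder assumptions, and how train()/train_loop actually consumes these callbacks is not shown in this commit. Only the function names, parameters, and callback keys come from the diff above.

from pathlib import Path

import torch
from accelerate import Accelerator

from training.functional import get_models
from training.strategy.ti import textual_inversion_strategy

accelerator = Accelerator()

# get_models() is defined in training/functional.py; the model path is a placeholder.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
    get_models("runwayml/stable-diffusion-v1-5")

# train_dataloader, val_dataloader, placeholder_tokens and placeholder_token_ids are
# assumed to have been built elsewhere (dataset setup is outside this patch).
callbacks = textual_inversion_strategy(
    accelerator=accelerator,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
    sample_scheduler=sample_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    dtype=torch.float16,                 # assumed sampling dtype
    output_dir=Path("output/ti"),        # placeholder output directory
    seed=42,                             # placeholder seed
    placeholder_tokens=["<new-token>"],  # placeholder token
    placeholder_token_ids=[[49408]],     # placeholder ids
    learning_rate=1e-4,
    use_ema=True,
)

callbacks["on_prepare"]()                     # enable grads on the temp token embedding
with callbacks["on_train"](epoch=0):          # tokenizer in train mode for this epoch
    # ... forward/backward and optimizer.step() would go here ...
    callbacks["on_after_optimize"](lr=1e-4)   # EMA step (plus optional embedding decay)
print(callbacks["on_log"]())                  # e.g. {"ema_decay": ...} when EMA is enabled
with callbacks["on_eval"]():                  # tokenizer in eval mode, EMA weights applied
    pass                                      # ... validation loop ...
callbacks["on_checkpoint"](500, "milestone")  # intended to write checkpoints/<slug>_500_milestone.bin per token
callbacks["on_sample"](500)                   # writes samples/{stable,val,train}/step_500.jpg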