From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 18:59:26 +0100
Subject: More modularization

---
 training/modules/ti.py | 284 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 284 insertions(+)
 create mode 100644 training/modules/ti.py

diff --git a/training/modules/ti.py b/training/modules/ti.py
new file mode 100644
index 0000000..2db6f88
--- /dev/null
+++ b/training/modules/ti.py
@@ -0,0 +1,284 @@
+from typing import Literal
+from functools import partial
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+from slugify import slugify
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
+from training.util import EMAModel, CheckpointerBase
+
+
+class Checkpointer(CheckpointerBase):
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
+        weight_dtype: torch.dtype,
+        scheduler,
+        placeholder_token,
+        placeholder_token_ids,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.weight_dtype = weight_dtype
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
+        self.scheduler = scheduler
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+        del text_encoder
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            text_encoder.to(dtype=orig_dtype)
+
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def train_ti(
+    setup: TrainingSetup,
+    num_train_epochs: int = 100,
+    num_class_images: int = 0,
+    prior_loss_weight: float = 1.0,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: float = 4/5,
+    ema_max_decay: float = .9999,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 0,
+    adam_epsilon: float = 1e-08,
+    adam_amsgrad: bool = False,
+    lr_scheduler: Literal[
+        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
+    ] = "one_cycle",
+    lr_min_lr: float = 0.04,
+    lr_warmup_func: Literal["linear", "cos"] = "cos",
+    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
+    lr_warmup_exp: int = 1,
+    lr_annealing_exp: int = 1,
+    lr_cycles: int = 1,
+    lr_warmup_epochs: int = 10,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    sample_image_size: int = 768,
+    sample_batch_size: int = 1,
+    sample_batches: int = 1,
+    sample_frequency: int = 10,
+    sample_steps: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+):
+    if use_ema:
+        ema_embeddings = EMAModel(
+            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    setup.text_encoder.requires_grad_(True)
+    setup.text_encoder.text_model.encoder.requires_grad_(False)
+    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
+
+    # Initialize the optimizer
+    optimizer = setup.optimizer_class(
+        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=setup.learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+        amsgrad=adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        min_lr=lr_min_lr,
+        warmup_func=lr_warmup_func,
+        annealing_func=lr_annealing_func,
+        warmup_exp=lr_warmup_exp,
+        annealing_exp=lr_annealing_exp,
+        cycles=lr_cycles,
+        train_epochs=num_train_epochs,
+        warmup_epochs=lr_warmup_epochs,
+        num_training_steps_per_epoch=len(setup.train_dataloader),
+        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
+    )
+
+    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
+        setup.text_encoder, optimizer, lr_scheduler
+    )
+
+    # Move vae and unet to device
+    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
+    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
+
+    if use_ema:
+        ema_embeddings.to(setup.accelerator.device)
+
+    setup.unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            setup.tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            setup.tokenizer.eval()
+
+            ema_context = nullcontext()
+            if use_ema:
+                ema_context = ema_embeddings.apply_temporary(
+                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        text_encoder.text_model.embeddings.normalize(
+            emb_decay_target,
+            min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
+        )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    loss_step_ = partial(
+        loss_step,
+        setup.vae,
+        setup.noise_scheduler,
+        setup.unet,
+        text_encoder,
+        num_class_images != 0,
+        prior_loss_weight,
+        setup.seed,
+    )
+
+    checkpointer = Checkpointer(
+        accelerator=setup.accelerator,
+        vae=setup.vae,
+        unet=setup.unet,
+        tokenizer=setup.tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        weight_dtype=setup.weight_dtype,
+        scheduler=setup.checkpoint_scheduler,
+        placeholder_token=setup.placeholder_token,
+        placeholder_token_ids=setup.placeholder_token_ids,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        output_dir=setup.output_dir,
+        seed=setup.seed,
+        sample_image_size=sample_image_size,
+        sample_batch_size=sample_batch_size,
+        sample_batches=sample_batches
+    )
+
+    if setup.accelerator.is_main_process:
+        setup.accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=setup.accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        checkpointer=checkpointer,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        sample_steps=sample_steps,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval
+    )
-- 
cgit v1.2.3-54-g00ecf