From 3e7fbb7dce321435bbbb81361debfbc499bf9231 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 22:25:30 +0100 Subject: Reverted modularization mostly --- training/modules/ti.py | 284 ------------------------------------------------- 1 file changed, 284 deletions(-) delete mode 100644 training/modules/ti.py (limited to 'training/modules/ti.py') diff --git a/training/modules/ti.py b/training/modules/ti.py deleted file mode 100644 index 2db6f88..0000000 --- a/training/modules/ti.py +++ /dev/null @@ -1,284 +0,0 @@ -from typing import Literal -from functools import partial -from contextlib import contextmanager, nullcontext - -import torch - -from slugify import slugify - -from accelerate import Accelerator -from transformers import CLIPTextModel -from diffusers import AutoencoderKL, UNet2DConditionModel - -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.tokenizer import MultiCLIPTokenizer - -from training.common import TrainingSetup, get_scheduler, train_loop, loss_step -from training.util import EMAModel, CheckpointerBase - - -class Checkpointer(CheckpointerBase): - def __init__( - self, - accelerator: Accelerator, - vae: AutoencoderKL, - unet: UNet2DConditionModel, - tokenizer: MultiCLIPTokenizer, - text_encoder: CLIPTextModel, - ema_embeddings: EMAModel, - weight_dtype: torch.dtype, - scheduler, - placeholder_token, - placeholder_token_ids, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.weight_dtype = weight_dtype - self.accelerator = accelerator - self.vae = vae - self.unet = unet - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.ema_embeddings = ema_embeddings - self.scheduler = scheduler - self.placeholder_token = placeholder_token - self.placeholder_token_ids = placeholder_token_ids - - @torch.no_grad() - def checkpoint(self, step, postfix): - print("Saving checkpoint for step %d..." % step) - - checkpoints_path = self.output_dir.joinpath("checkpoints") - checkpoints_path.mkdir(parents=True, exist_ok=True) - - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = nullcontext() - if self.ema_embeddings is not None: - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") - ) - - del text_encoder - - @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = nullcontext() - if self.ema_embeddings is not None: - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - orig_dtype = text_encoder.dtype - text_encoder.to(dtype=self.weight_dtype) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=self.unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - - text_encoder.to(dtype=orig_dtype) - - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def train_ti( - setup: TrainingSetup, - num_train_epochs: int = 100, - num_class_images: int = 0, - prior_loss_weight: float = 1.0, - use_ema: bool = False, - ema_inv_gamma: float = 1.0, - ema_power: float = 4/5, - ema_max_decay: float = .9999, - adam_beta1: float = 0.9, - adam_beta2: float = 0.999, - adam_weight_decay: float = 0, - adam_epsilon: float = 1e-08, - adam_amsgrad: bool = False, - lr_scheduler: Literal[ - "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle" - ] = "one_cycle", - lr_min_lr: float = 0.04, - lr_warmup_func: Literal["linear", "cos"] = "cos", - lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos", - lr_warmup_exp: int = 1, - lr_annealing_exp: int = 1, - lr_cycles: int = 1, - lr_warmup_epochs: int = 10, - emb_decay_target: float = 0.4, - emb_decay_factor: float = 1, - emb_decay_start: float = 1e-4, - sample_image_size: int = 768, - sample_batch_size: int = 1, - sample_batches: int = 1, - sample_frequency: int = 10, - sample_steps: int = 20, - checkpoint_frequency: int = 50, - global_step_offset: int = 0, -): - if use_ema: - ema_embeddings = EMAModel( - setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - inv_gamma=ema_inv_gamma, - power=ema_power, - max_value=ema_max_decay, - ) - else: - ema_embeddings = None - - setup.text_encoder.requires_grad_(True) - setup.text_encoder.text_model.encoder.requires_grad_(False) - setup.text_encoder.text_model.final_layer_norm.requires_grad_(False) - setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - - # Initialize the optimizer - optimizer = setup.optimizer_class( - setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - lr=setup.learning_rate, - betas=(adam_beta1, adam_beta2), - weight_decay=adam_weight_decay, - eps=adam_epsilon, - amsgrad=adam_amsgrad, - ) - - lr_scheduler = get_scheduler( - lr_scheduler, - optimizer=optimizer, - min_lr=lr_min_lr, - warmup_func=lr_warmup_func, - annealing_func=lr_annealing_func, - warmup_exp=lr_warmup_exp, - annealing_exp=lr_annealing_exp, - cycles=lr_cycles, - train_epochs=num_train_epochs, - warmup_epochs=lr_warmup_epochs, - num_training_steps_per_epoch=len(setup.train_dataloader), - gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps - ) - - text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare( - setup.text_encoder, optimizer, lr_scheduler - ) - - # Move vae and unet to device - setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype) - setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype) - - if use_ema: - ema_embeddings.to(setup.accelerator.device) - - setup.unet.train() - - @contextmanager - def on_train(epoch: int): - try: - setup.tokenizer.train() - yield - finally: - pass - - @contextmanager - def on_eval(): - try: - setup.tokenizer.eval() - - ema_context = nullcontext() - if use_ema: - ema_context = ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - yield - finally: - pass - - @torch.no_grad() - def on_after_optimize(lr: float): - text_encoder.text_model.embeddings.normalize( - emb_decay_target, - min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start)))) - ) - - if use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - def on_log(): - if use_ema: - return {"ema_decay": ema_embeddings.decay} - return {} - - loss_step_ = partial( - loss_step, - setup.vae, - setup.noise_scheduler, - setup.unet, - text_encoder, - num_class_images != 0, - prior_loss_weight, - setup.seed, - ) - - checkpointer = Checkpointer( - accelerator=setup.accelerator, - vae=setup.vae, - unet=setup.unet, - tokenizer=setup.tokenizer, - text_encoder=text_encoder, - ema_embeddings=ema_embeddings, - weight_dtype=setup.weight_dtype, - scheduler=setup.checkpoint_scheduler, - placeholder_token=setup.placeholder_token, - placeholder_token_ids=setup.placeholder_token_ids, - train_dataloader=setup.train_dataloader, - val_dataloader=setup.val_dataloader, - output_dir=setup.output_dir, - seed=setup.seed, - sample_image_size=sample_image_size, - sample_batch_size=sample_batch_size, - sample_batches=sample_batches - ) - - if setup.accelerator.is_main_process: - setup.accelerator.init_trackers("textual_inversion") - - train_loop( - accelerator=setup.accelerator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model=text_encoder, - checkpointer=checkpointer, - train_dataloader=setup.train_dataloader, - val_dataloader=setup.val_dataloader, - loss_step=loss_step_, - sample_frequency=sample_frequency, - sample_steps=sample_steps, - checkpoint_frequency=checkpoint_frequency, - global_step_offset=global_step_offset, - num_epochs=num_train_epochs, - on_log=on_log, - on_train=on_train, - on_after_optimize=on_after_optimize, - on_eval=on_eval - ) -- cgit v1.2.3-54-g00ecf