from typing import Literal

from functools import partial
from contextlib import contextmanager, nullcontext

import torch

from slugify import slugify

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
from training.util import EMAModel, CheckpointerBase


# Writes embedding checkpoints and sample images during textual inversion training.
class Checkpointer(CheckpointerBase):
    def __init__(
        self,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        tokenizer: MultiCLIPTokenizer,
        text_encoder: CLIPTextModel,
        ema_embeddings: EMAModel,
        weight_dtype: torch.dtype,
        scheduler,
        placeholder_token,
        placeholder_token_ids,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.ema_embeddings = ema_embeddings
        self.scheduler = scheduler
        self.placeholder_token = placeholder_token
        self.placeholder_token_ids = placeholder_token_ids

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        print("Saving checkpoint for step %d..." % step)

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        # Temporarily swap in the EMA weights for the saved embeddings, if enabled.
        ema_context = nullcontext()
        if self.ema_embeddings is not None:
            ema_context = self.ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters())

        with ema_context:
            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

        del text_encoder

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = nullcontext()
        if self.ema_embeddings is not None:
            ema_context = self.ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters())

        with ema_context:
            orig_dtype = text_encoder.dtype
            text_encoder.to(dtype=self.weight_dtype)

            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=self.vae,
                unet=self.unet,
                tokenizer=self.tokenizer,
                scheduler=self.scheduler,
            ).to(self.accelerator.device)
            pipeline.set_progress_bar_config(dynamic_ncols=True)

            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)

            text_encoder.to(dtype=orig_dtype)

        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def train_ti(
    setup: TrainingSetup,
    num_train_epochs: int = 100,
    num_class_images: int = 0,
    prior_loss_weight: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: float = 4 / 5,
    ema_max_decay: float = .9999,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 0,
    adam_epsilon: float = 1e-08,
    adam_amsgrad: bool = False,
    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial",
        "constant", "constant_with_warmup", "one_cycle"
    ] = "one_cycle",
    lr_min_lr: float = 0.04,
    lr_warmup_func: Literal["linear", "cos"] = "cos",
    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
    lr_warmup_exp: int = 1,
    lr_annealing_exp: int = 1,
    lr_cycles: int = 1,
    lr_warmup_epochs: int = 10,
    emb_decay_target: float = 0.4,
    emb_decay_factor: float = 1,
    emb_decay_start: float = 1e-4,
    sample_image_size: int = 768,
    sample_batch_size: int = 1,
    sample_batches: int = 1,
    sample_frequency: int = 10,
    sample_steps: int = 20,
    checkpoint_frequency: int = 50,
    global_step_offset: int = 0,
):
    if use_ema:
        ema_embeddings = EMAModel(
            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_embeddings = None

    # Only the temporary placeholder token embeddings are trained; the rest of the
    # text encoder stays frozen.
    setup.text_encoder.requires_grad_(True)
    setup.text_encoder.text_model.encoder.requires_grad_(False)
    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

    # Initialize the optimizer
    optimizer = setup.optimizer_class(
        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=setup.learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
        amsgrad=adam_amsgrad,
    )

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        min_lr=lr_min_lr,
        warmup_func=lr_warmup_func,
        annealing_func=lr_annealing_func,
        warmup_exp=lr_warmup_exp,
        annealing_exp=lr_annealing_exp,
        cycles=lr_cycles,
        train_epochs=num_train_epochs,
        warmup_epochs=lr_warmup_epochs,
        num_training_steps_per_epoch=len(setup.train_dataloader),
        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
    )

    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
        setup.text_encoder, optimizer, lr_scheduler
    )

    # Move vae and unet to device
    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)

    if use_ema:
        ema_embeddings.to(setup.accelerator.device)

    setup.unet.train()

    @contextmanager
    def on_train(epoch: int):
        try:
            setup.tokenizer.train()
            yield
        finally:
            pass

    @contextmanager
    def on_eval():
        try:
            setup.tokenizer.eval()

            # Evaluate with the EMA embeddings applied temporarily, if enabled.
            ema_context = nullcontext()
            if use_ema:
                ema_context = ema_embeddings.apply_temporary(
                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())

            with ema_context:
                yield
        finally:
            pass

    @torch.no_grad()
    def on_after_optimize(lr: float):
        # Pull the trainable embeddings back towards the target norm; the decay
        # strength is scaled by the current learning rate and clamped to [0, 1].
        text_encoder.text_model.embeddings.normalize(
            emb_decay_target,
            min(1.0, max(0.0, emb_decay_factor * (
                (lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
        )

        if use_ema:
            ema_embeddings.step(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log():
        if use_ema:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    # Bind the fixed arguments of the shared loss function.
    loss_step_ = partial(
        loss_step,
        setup.vae,
        setup.noise_scheduler,
        setup.unet,
        text_encoder,
        num_class_images != 0,
        prior_loss_weight,
        setup.seed,
    )

    checkpointer = Checkpointer(
        accelerator=setup.accelerator,
        vae=setup.vae,
        unet=setup.unet,
        tokenizer=setup.tokenizer,
        text_encoder=text_encoder,
        ema_embeddings=ema_embeddings,
        weight_dtype=setup.weight_dtype,
        scheduler=setup.checkpoint_scheduler,
        placeholder_token=setup.placeholder_token,
        placeholder_token_ids=setup.placeholder_token_ids,
        train_dataloader=setup.train_dataloader,
        val_dataloader=setup.val_dataloader,
        output_dir=setup.output_dir,
        seed=setup.seed,
        sample_image_size=sample_image_size,
        sample_batch_size=sample_batch_size,
        sample_batches=sample_batches
    )

    if setup.accelerator.is_main_process:
        setup.accelerator.init_trackers("textual_inversion")

    train_loop(
        accelerator=setup.accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=text_encoder,
        checkpointer=checkpointer,
        train_dataloader=setup.train_dataloader,
        val_dataloader=setup.val_dataloader,
        loss_step=loss_step_,
        sample_frequency=sample_frequency,
        sample_steps=sample_steps,
        checkpoint_frequency=checkpoint_frequency,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        on_log=on_log,
        on_train=on_train,
        on_after_optimize=on_after_optimize,
        on_eval=on_eval
    )
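

# Usage sketch (illustrative only): one way `train_ti` might be wired up from a
# driver script. It assumes some factory elsewhere in the repo builds the
# `TrainingSetup` (the `make_training_setup` name and its arguments below are
# hypothetical, not part of this module):
#
#   setup = make_training_setup(...)
#   train_ti(
#       setup,
#       num_train_epochs=100,
#       use_ema=True,
#       lr_scheduler="one_cycle",
#       sample_frequency=10,
#       checkpoint_frequency=50,
#   )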