from contextlib import contextmanager, nullcontext

import torch
from slugify import slugify

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from trainer.base import TrainingStrategy, Checkpointer
from training.util import EMAModel


class TextualInversionCheckpointer(Checkpointer):
    def __init__(
        self,
        ema_embeddings: EMAModel,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.ema_embeddings = ema_embeddings

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        # Temporarily swap in the EMA-averaged embeddings (if enabled) while saving.
        ema_context = self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

    @torch.inference_mode()
    def save_samples(self, step):
        # Generate samples with the EMA-averaged embeddings (if enabled).
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)


class TextualInversionTrainingStrategy(TrainingStrategy):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: int = 1,
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

        # Only the temporary token embeddings are trained; the rest of the text
        # encoder and the UNet stay frozen.
        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            *args,
            **kwargs
        )

    @property
    def main_model(self):
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
        if self.gradient_checkpointing:
            # The UNet must be in train mode for gradient checkpointing to take effect.
            self.unet.train()

        with super().on_train(epoch):
            yield

    @contextmanager
    def on_eval(self):
        if self.gradient_checkpointing:
            self.unet.eval()

        # Evaluate with the EMA-averaged embeddings (if enabled).
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context, super().on_eval():
            yield

    @torch.no_grad()
    def on_after_optimize(self, lr: float):
        if self.use_emb_decay:
            # Decay the embedding norms towards the target, scaled by where the
            # current learning rate sits between emb_decay_start and the base
            # learning rate (clamped to [0, 1]).
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * (
                    (lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start)
                )))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )

    def on_log(self):
        log = super().on_log()
        added = {}

        if self.ema_embeddings is not None:
            added = {"ema_decay": self.ema_embeddings.decay}

        return {**log, **added}