From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 21:53:07 +0100
Subject: WIP: Modularization ("free(): invalid pointer" my ass)

---
 trainer/ti.py | 164 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 164 insertions(+)
 create mode 100644 trainer/ti.py

(limited to 'trainer/ti.py')

diff --git a/trainer/ti.py b/trainer/ti.py
new file mode 100644
index 0000000..15cf747
--- /dev/null
+++ b/trainer/ti.py
@@ -0,0 +1,164 @@
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+from slugify import slugify
+
+from diffusers import UNet2DConditionModel
+from transformers import CLIPTextModel
+
+from trainer.base import TrainingStrategy, Checkpointer
+from training.util import EMAModel
+
+
+class TextualInversionCheckpointer(Checkpointer):
+    def __init__(
+        self,
+        ema_embeddings: EMAModel,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.ema_embeddings = ema_embeddings
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+    @torch.inference_mode()
+    def save_samples(self, step):
+        ema_context = self.ema_embeddings.apply_temporary(
+            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            super().save_samples(step)
+
+
+class TextualInversionTrainingStrategy(TrainingStrategy):
+    def __init__(
+        self,
+        unet: UNet2DConditionModel,
+        text_encoder: CLIPTextModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
+        learning_rate: float,
+        gradient_checkpointing: bool = False,
+        use_emb_decay: bool = False,
+        emb_decay_target: float = 0.4,
+        emb_decay_factor: float = 1,
+        emb_decay_start: float = 1e-4,
+        use_ema: bool = False,
+        ema_inv_gamma: float = 1.0,
+        ema_power: int = 1,
+        ema_max_decay: float = 0.9999,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            unet=unet,
+            text_encoder=text_encoder,
+            *args,
+            **kwargs
+        )
+
+        self.text_encoder = text_encoder
+        self.unet = unet
+
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
+
+        self.gradient_checkpointing = gradient_checkpointing
+
+        self.learning_rate = learning_rate
+        self.use_emb_decay = use_emb_decay
+        self.emb_decay_target = emb_decay_target
+        self.emb_decay_factor = emb_decay_factor
+        self.emb_decay_start = emb_decay_start
+
+        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        self.ema_embeddings = None
+
+        if use_ema:
+            self.ema_embeddings = EMAModel(
+                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+                inv_gamma=ema_inv_gamma,
+                power=ema_power,
+                max_value=ema_max_decay,
+            )
+
+        self.checkpointer = TextualInversionCheckpointer(
+            unet=unet,
+            text_encoder=text_encoder,
+            ema_embeddings=self.ema_embeddings,
+            *args,
+            **kwargs
+        )
+
+    @property
+    def main_model(self):
+        return self.text_encoder
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        try:
+            if self.gradient_checkpointing:
+                self.unet.train()
+
+            with super().on_train(epoch):
+                yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval(self):
+        try:
+            if self.gradient_checkpointing:
+                self.unet.eval()
+
+            ema_context = self.ema_embeddings.apply_temporary(
+                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            ) if self.ema_embeddings is not None else nullcontext()
+
+            with ema_context, super().on_eval():
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(self, lr: float):
+        if self.use_emb_decay:
+            self.text_encoder.text_model.embeddings.normalize(
+                self.emb_decay_target,
+                min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
+            )
+
+        if self.ema_embeddings is not None:
+            self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log(self):
+        log = super().on_log()
+        added = {}
+
+        if self.ema_embeddings is not None:
+            added = {"ema_decay": self.ema_embeddings.decay}
+
+        return {**log, **added}
--
cgit v1.2.3-54-g00ecf
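
The new classes code against TrainingStrategy and Checkpointer from trainer/base.py, which is not included in this diff. For orientation only, here is a minimal sketch of the base interface that trainer/ti.py appears to assume; all attribute and hook names below are inferred from the calls in the patch, and the real trainer/base.py may differ.

    # Inferred sketch only: the actual trainer/base.py is not shown in this patch.
    from contextlib import contextmanager
    from pathlib import Path


    class Checkpointer:
        def __init__(self, accelerator=None, unet=None, text_encoder=None,
                     output_dir=Path("output"), placeholder_tokens=None,
                     placeholder_token_ids=None, *args, **kwargs):
            # Attributes that TextualInversionCheckpointer reads in checkpoint().
            self.accelerator = accelerator
            self.unet = unet
            self.text_encoder = text_encoder
            self.output_dir = output_dir
            self.placeholder_tokens = placeholder_tokens or []
            self.placeholder_token_ids = placeholder_token_ids or []

        def checkpoint(self, step, postfix):
            raise NotImplementedError

        def save_samples(self, step):
            # Base sample generation; the TI subclass wraps this call in the
            # EMA context before delegating here.
            pass


    class TrainingStrategy:
        def __init__(self, unet=None, text_encoder=None, *args, **kwargs):
            self.unet = unet
            self.text_encoder = text_encoder

        @property
        def main_model(self):
            raise NotImplementedError

        @contextmanager
        def on_train(self, epoch: int):
            # Entered once per training epoch by the trainer loop.
            yield

        @contextmanager
        def on_eval(self):
            # Entered around validation / sample generation.
            yield

        def on_after_optimize(self, lr: float):
            pass

        def on_log(self):
            return {}

Under this assumed shape, the TI strategy only overrides the hooks it needs (EMA context on eval, embedding decay after each optimizer step) and leaves the training loop itself to the base class.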