From f00877a13bce50b02cfc3790f2d18a325e9ff95b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 22:42:44 +0100
Subject: Update

---
 trainer/ti.py | 164 ----------------------------------------------------------
 1 file changed, 164 deletions(-)
 delete mode 100644 trainer/ti.py

diff --git a/trainer/ti.py b/trainer/ti.py
deleted file mode 100644
index 388acd3..0000000
--- a/trainer/ti.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from contextlib import contextmanager, nullcontext
-
-import torch
-
-from slugify import slugify
-
-from diffusers import UNet2DConditionModel
-from transformers import CLIPTextModel
-
-from trainer.base import TrainingStrategy, Checkpointer
-from training.util import EMAModel
-
-
-class TextualInversionCheckpointer(Checkpointer):
-    def __init__(
-        self,
-        ema_embeddings: EMAModel,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.ema_embeddings = ema_embeddings
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print(f"Saving checkpoint for step {step}...")
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        ema_context = self.ema_embeddings.apply_temporary(
-            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if self.ema_embeddings is not None else nullcontext()
-
-        with ema_context:
-            super().save_samples(step)
-
-
-class TextualInversionTrainingStrategy(TrainingStrategy):
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        text_encoder: CLIPTextModel,
-        placeholder_tokens: list[str],
-        placeholder_token_ids: list[list[int]],
-        learning_rate: float,
-        gradient_checkpointing: bool = False,
-        use_emb_decay: bool = False,
-        emb_decay_target: float = 0.4,
-        emb_decay_factor: float = 1,
-        emb_decay_start: float = 1e-4,
-        use_ema: bool = False,
-        ema_inv_gamma: float = 1.0,
-        ema_power: int = 1,
-        ema_max_decay: float = 0.9999,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(
-            unet=unet,
-            text_encoder=text_encoder,
-            *args,
-            **kwargs
-        )
-
-        self.text_encoder = text_encoder
-        self.unet = unet
-
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
-
-        self.gradient_checkpointing = gradient_checkpointing
-
-        self.learning_rate = learning_rate
-        self.use_emb_decay = use_emb_decay
-        self.emb_decay_target = emb_decay_target
-        self.emb_decay_factor = emb_decay_factor
-        self.emb_decay_start = emb_decay_start
-
-        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        self.ema_embeddings = None
-
-        if use_ema:
-            self.ema_embeddings = EMAModel(
-                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-                inv_gamma=ema_inv_gamma,
-                power=ema_power,
-                max_value=ema_max_decay,
-            )
-
-        self.checkpointer = TextualInversionCheckpointer(
-            unet=unet,
-            text_encoder=text_encoder,
-            ema_embeddings=self.ema_embeddings,
-            *args,
-            **kwargs
-        )
-
-    @property
-    def main_model(self):
-        return self.text_encoder
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        try:
-            if self.gradient_checkpointing:
-                self.unet.train()
-
-            with super().on_eval():
-                yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval(self):
-        try:
-            if self.gradient_checkpointing:
-                self.unet.eval()
-
-            ema_context = self.ema_embeddings.apply_temporary(
-                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-            ) if self.ema_embeddings is not None else nullcontext()
-
-            with ema_context, super().on_eval():
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(self, lr: float):
-        if self.use_emb_decay:
-            self.text_encoder.text_model.embeddings.normalize(
-                self.emb_decay_target,
-                min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
-            )
-
-        if self.ema_embeddings is not None:
-            self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log(self):
-        log = super().on_log()
-        added = {}
-
-        if self.ema_embeddings is not None:
-            added = {"ema_decay": self.ema_embeddings.decay}
-
-        return log.update(added)
--
cgit v1.2.3-54-g00ecf
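For context, the deleted on_after_optimize hook weights the embedding decay by how far the current learning rate lr has climbed from emb_decay_start toward the base learning_rate, clamped to [0, 1]. A minimal standalone sketch of that weighting follows; the helper name and example values are illustrative, not part of the removed module:

def emb_decay_weight(lr, learning_rate, emb_decay_factor=1.0, emb_decay_start=1e-4):
    # Mirrors the clamped interpolation from on_after_optimize:
    # 0.0 at lr == emb_decay_start, emb_decay_factor at lr == learning_rate.
    t = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * t))

# Hypothetical example: with a base learning rate of 1e-2, a current warmup
# lr of 5e-3 yields a weight of about 0.49, so decay ramps up with the lr.
print(emb_decay_weight(5e-3, 1e-2))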