From 30b557c8e1f03b4748ac3efca599ff51d66561cb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Apr 2023 07:30:43 +0200 Subject: TI: Bring back old embedding decay --- training/functional.py | 4 ++-- training/strategy/ti.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 1d8e2ee..96ecbc1 100644 --- a/training/functional.py +++ b/training/functional.py @@ -73,7 +73,7 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): +def get_models(pretrained_model_name_or_path: str): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder, emb_alpha) + embeddings = patch_managed_embeddings(text_encoder) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 95128da..9df160a 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks( seed: int, placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], + use_emb_decay: bool = False, + emb_decay_target: float = 0.4, + emb_decay: float = 1e-2, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, @@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_after_optimize(zero_ids, lr: float): + def on_before_optimize(lr: float, epoch: int): + if use_emb_decay: + return torch.stack([ + p + for p in text_encoder.text_model.embeddings.token_override_embedding.params + if p.grad is not None + ]) + + @torch.no_grad() + def on_after_optimize(w, lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) + if use_emb_decay: + lambda_ = emb_decay * lr + + if lambda_ != 0: + norm = w[:, :].norm(dim=-1, keepdim=True) + w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} @@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks( on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, + on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-54-g00ecf