From 555912a86b012382a78f1b2717c2e0fde5994a04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 11:50:16 +0100
Subject: Make embedding decay work like Adam decay

---
 training/strategy/ti.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 081180f..eb6730b 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,12 +32,10 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    learning_rate: float,
     gradient_checkpointing: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 0,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -120,17 +118,15 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    @torch.no_grad()
-    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
-                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+                min(1.0, emb_decay * lr)
             )
 
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
-- 
cgit v1.2.3-54-g00ecf
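
Note (not part of the patch): a minimal sketch of how the new learning-rate-scaled decay factor behaves. The tensor, the hypothetical decay_step helper, and the norm arithmetic below are assumptions for illustration only; the repository's actual normalize() method is not reproduced here. The point is that the interpolation strength is now min(1.0, emb_decay * lr), analogous to AdamW's decoupled weight decay (w <- w - lr * wd * w), instead of being interpolated over training progress.

import torch

emb_decay = 1e-2          # new default introduced by the patch
emb_decay_target = 0.4    # target norm, as in the patch

def decay_step(embedding: torch.Tensor, lr: float) -> torch.Tensor:
    # Interpolation factor, clamped exactly as in the patched call:
    # min(1.0, emb_decay * lr)
    factor = min(1.0, emb_decay * lr)
    norm = embedding.norm(dim=-1, keepdim=True).clamp(min=1e-12)
    # Move each embedding's norm a fraction `factor` of the way toward the
    # target norm; larger lr means stronger decay, like AdamW weight decay.
    new_norm = norm + factor * (emb_decay_target - norm)
    return embedding * (new_norm / norm)

# Example: with lr = 1e-3 the per-step factor is 1e-5, so the decay strength
# now tracks the learning rate rather than the fraction of training completed.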