From 555912a86b012382a78f1b2717c2e0fde5994a04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 11:50:16 +0100
Subject: Make embedding decay work like Adam decay

---
 train_ti.py             | 16 ++++------------
 training/strategy/ti.py | 14 +++++---------
 2 files changed, 9 insertions(+), 21 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 0891c49..fc34d27 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -159,7 +159,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -406,17 +406,11 @@ def parse_args():
         help="Embedding decay target."
     )
     parser.add_argument(
-        "--emb_decay_factor",
-        default=1,
+        "--emb_decay",
+        default=1e-1,
         type=float,
         help="Embedding decay factor."
     )
-    parser.add_argument(
-        "--emb_decay_start",
-        default=0,
-        type=float,
-        help="Embedding decay start offset."
-    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -587,12 +581,10 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        learning_rate=args.learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 081180f..eb6730b 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,12 +32,10 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    learning_rate: float,
     gradient_checkpointing: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 0,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -120,17 +118,15 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    @torch.no_grad()
-    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
-                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+                min(1.0, emb_decay * lr)
             )
 
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
-- 
cgit v1.2.3-54-g00ecf
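
Editorial note on the new decay rule (not part of the patch): the old code scaled the
decay by (lr - emb_decay_start) / (learning_rate - emb_decay_start), while the new code
couples it directly to the current learning rate, the same way AdamW applies decoupled
weight decay as lr * weight_decay. The sketch below only illustrates the scaling factor
that is now passed to embeddings.normalize(); the helper name and example values are
assumptions, not code from the repository.

# Illustrative sketch of the decay factor after this patch (assumed helper, not from the repo).
# `emb_decay` is the new CLI option; `lr` is the learning rate reported to on_after_optimize.
def embedding_decay_lambda(emb_decay: float, lr: float) -> float:
    # Decay strength follows the learning rate, like AdamW's lr * weight_decay,
    # clamped to 1.0 so the pull towards emb_decay_target can never overshoot.
    return min(1.0, emb_decay * lr)

# Example with assumed values: emb_decay=1e-1 (train_ti.py default) and lr=1e-3
# gives a factor of 1e-4, i.e. a small per-step nudge towards emb_decay_target.
print(embedding_decay_lambda(1e-1, 1e-3))  # 0.0001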