From c96073646bbb638d7d78fdd7d9fdeed08d1454b5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 16:30:36 +0200 Subject: Experimental: TI via LoRA --- training/strategy/ti.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) (limited to 'training/strategy') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index b9a5547..19b8d25 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks( placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], gradient_checkpointing: bool = False, - use_emb_decay: bool = False, - emb_decay_target: float = 0.4, - emb_decay: float = 1e-2, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, @@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks( if use_ema: ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + text_encoder.text_model.embeddings.overlay.parameters(), inv_gamma=ema_inv_gamma, power=ema_power, max_value=ema_max_decay, @@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks( def ema_context(): if ema_embeddings is not None: return ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters() + text_encoder.text_model.embeddings.overlay.parameters() ) else: return nullcontext() def on_accum_model(): - return text_encoder.text_model.embeddings.temp_token_embedding + return text_encoder.text_model.embeddings.overlay @contextmanager def on_train(epoch: int): @@ -105,28 +102,10 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield - @torch.no_grad() - def on_before_optimize(lr: float, epoch: int): - if use_emb_decay: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - return torch.all(w.grad == 0, dim=1) - @torch.no_grad() def on_after_optimize(zero_ids, lr: float): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - if use_emb_decay: - lambda_ = emb_decay * lr - - if lambda_ != 0: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - - mask = torch.ones(w.shape[0], dtype=torch.bool) - mask[zero_ids] = False - - norm = w[mask, :].norm(dim=-1, keepdim=True) - w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters()) def on_log(): if ema_embeddings is not None: @@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks( on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, - on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2