From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:11:11 +0200 Subject: TI via LoRA --- training/strategy/ti.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'training/strategy/ti.py') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index ca7cc3d..49236c6 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( if use_ema: ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.token_override_embedding.parameters(), + text_encoder.text_model.embeddings.token_embedding.parameters(), inv_gamma=ema_inv_gamma, power=ema_power, max_value=ema_max_decay, @@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks( def ema_context(): if ema_embeddings is not None: return ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.token_override_embedding.parameters() + text_encoder.text_model.embeddings.token_embedding.parameters() ) else: return nullcontext() @@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( if use_emb_decay: params = [ p - for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() + for p in text_encoder.text_model.embeddings.token_embedding.parameters() if p.grad is not None ] return torch.stack(params) if len(params) != 0 else None @@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) + ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -203,7 +203,6 @@ def textual_inversion_prepare( text_encoder.text_model.encoder.requires_grad_(False) text_encoder.text_model.final_layer_norm.requires_grad_(False) text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler -- cgit v1.2.3-54-g00ecf