From 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 17:38:49 +0200
Subject: Update

---
 training/strategy/lora.py |  2 +-
 training/strategy/ti.py   | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cfdc504..ae85401 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 720ebf3..289d6bd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,20 +84,20 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.parameters()
             )
         else:
             return nullcontext()
 
     @contextmanager
     def on_train(epoch: int):
-        text_encoder.text_model.embeddings.token_override_embedding.params.train()
+        text_encoder.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
-        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
+        text_encoder.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] or lrs["0"]
--
cgit v1.2.3-70-g09d2
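
A minimal sketch of what the rename relies on, assuming token_override_embedding is a torch nn.Module that keeps its learned vectors in an nn.ParameterList attribute named `params`; the TokenOverrideEmbedding class below is a hypothetical stand-in, not the repo's real implementation. Calling .parameters() on the module itself yields the same Parameter objects the old .params attribute exposed, so the emb-decay filter and the EMA update see identical tensors.

import torch
import torch.nn as nn

class TokenOverrideEmbedding(nn.Module):  # hypothetical stand-in for the repo's class
    def __init__(self, num_tokens: int = 4, dim: int = 8):
        super().__init__()
        # Learned override vectors stored in a ParameterList attribute named `params`
        self.params = nn.ParameterList(
            nn.Parameter(torch.randn(dim)) for _ in range(num_tokens)
        )

embedding = TokenOverrideEmbedding()

# Old access path (.params.parameters()) and new one (.parameters()) walk the
# same Parameter objects, because the ParameterList is a registered submodule.
old = list(embedding.params.parameters())
new = list(embedding.parameters())
assert len(old) == len(new) and all(a is b for a, b in zip(old, new))

# The emb-decay filter from the patch, written against .parameters(); here no
# backward pass has run, so every grad is None and the result is None.
params = [p for p in embedding.parameters() if p.grad is not None]
stacked = torch.stack(params) if len(params) != 0 else None

The second group of hunks makes a similar simplification for mode switching: text_encoder.train() / text_encoder.eval() set the mode recursively on every submodule, including the override embedding, instead of toggling only the `params` container.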