diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/strategy/ti.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 16baa34..95128da 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks( | |||
69 | 69 | ||
70 | if use_ema: | 70 | if use_ema: |
71 | ema_embeddings = EMAModel( | 71 | ema_embeddings = EMAModel( |
72 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 72 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
73 | inv_gamma=ema_inv_gamma, | 73 | inv_gamma=ema_inv_gamma, |
74 | power=ema_power, | 74 | power=ema_power, |
75 | max_value=ema_max_decay, | 75 | max_value=ema_max_decay, |
@@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks( | |||
81 | def ema_context(): | 81 | def ema_context(): |
82 | if ema_embeddings is not None: | 82 | if ema_embeddings is not None: |
83 | return ema_embeddings.apply_temporary( | 83 | return ema_embeddings.apply_temporary( |
84 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 84 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters() |
85 | ) | 85 | ) |
86 | else: | 86 | else: |
87 | return nullcontext() | 87 | return nullcontext() |
88 | 88 | ||
89 | def on_accum_model(): | 89 | def on_accum_model(): |
90 | return text_encoder.text_model.embeddings.temp_token_embedding | 90 | return text_encoder.text_model.embeddings.token_override_embedding.params |
91 | 91 | ||
92 | @contextmanager | 92 | @contextmanager |
93 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
104 | @torch.no_grad() | 104 | @torch.no_grad() |
105 | def on_after_optimize(zero_ids, lr: float): | 105 | def on_after_optimize(zero_ids, lr: float): |
106 | if ema_embeddings is not None: | 106 | if ema_embeddings is not None: |
107 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 107 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) |
108 | 108 | ||
109 | def on_log(): | 109 | def on_log(): |
110 | if ema_embeddings is not None: | 110 | if ema_embeddings is not None: |