diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/strategy/ti.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 16baa34..95128da 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks( | |||
| 69 | 69 | ||
| 70 | if use_ema: | 70 | if use_ema: |
| 71 | ema_embeddings = EMAModel( | 71 | ema_embeddings = EMAModel( |
| 72 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 72 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
| 73 | inv_gamma=ema_inv_gamma, | 73 | inv_gamma=ema_inv_gamma, |
| 74 | power=ema_power, | 74 | power=ema_power, |
| 75 | max_value=ema_max_decay, | 75 | max_value=ema_max_decay, |
| @@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks( | |||
| 81 | def ema_context(): | 81 | def ema_context(): |
| 82 | if ema_embeddings is not None: | 82 | if ema_embeddings is not None: |
| 83 | return ema_embeddings.apply_temporary( | 83 | return ema_embeddings.apply_temporary( |
| 84 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 84 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters() |
| 85 | ) | 85 | ) |
| 86 | else: | 86 | else: |
| 87 | return nullcontext() | 87 | return nullcontext() |
| 88 | 88 | ||
| 89 | def on_accum_model(): | 89 | def on_accum_model(): |
| 90 | return text_encoder.text_model.embeddings.temp_token_embedding | 90 | return text_encoder.text_model.embeddings.token_override_embedding.params |
| 91 | 91 | ||
| 92 | @contextmanager | 92 | @contextmanager |
| 93 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
| @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
| 104 | @torch.no_grad() | 104 | @torch.no_grad() |
| 105 | def on_after_optimize(zero_ids, lr: float): | 105 | def on_after_optimize(zero_ids, lr: float): |
| 106 | if ema_embeddings is not None: | 106 | if ema_embeddings is not None: |
| 107 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 107 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) |
| 108 | 108 | ||
| 109 | def on_log(): | 109 | def on_log(): |
| 110 | if ema_embeddings is not None: | 110 | if ema_embeddings is not None: |
