diff options
Diffstat (limited to 'training/strategy/ti.py')
| -rw-r--r-- | training/strategy/ti.py | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 597abd0..081180f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -88,7 +88,7 @@ def textual_inversion_strategy_callbacks( | |||
| 88 | ema_embeddings = None | 88 | ema_embeddings = None |
| 89 | 89 | ||
| 90 | def ema_context(): | 90 | def ema_context(): |
| 91 | if use_ema: | 91 | if ema_embeddings is not None: |
| 92 | return ema_embeddings.apply_temporary( | 92 | return ema_embeddings.apply_temporary( |
| 93 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 93 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
| 94 | ) | 94 | ) |
| @@ -101,7 +101,7 @@ def textual_inversion_strategy_callbacks( | |||
| 101 | def on_prepare(): | 101 | def on_prepare(): |
| 102 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 102 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) |
| 103 | 103 | ||
| 104 | if use_ema: | 104 | if ema_embeddings is not None: |
| 105 | ema_embeddings.to(accelerator.device) | 105 | ema_embeddings.to(accelerator.device) |
| 106 | 106 | ||
| 107 | if gradient_checkpointing: | 107 | if gradient_checkpointing: |
| @@ -120,7 +120,7 @@ def textual_inversion_strategy_callbacks( | |||
| 120 | yield | 120 | yield |
| 121 | 121 | ||
| 122 | def on_after_optimize(lr: float): | 122 | def on_after_optimize(lr: float): |
| 123 | if use_ema: | 123 | if ema_embeddings is not None: |
| 124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
| 125 | 125 | ||
| 126 | @torch.no_grad() | 126 | @torch.no_grad() |
| @@ -132,7 +132,7 @@ def textual_inversion_strategy_callbacks( | |||
| 132 | ) | 132 | ) |
| 133 | 133 | ||
| 134 | def on_log(): | 134 | def on_log(): |
| 135 | if use_ema: | 135 | if ema_embeddings is not None: |
| 136 | return {"ema_decay": ema_embeddings.decay} | 136 | return {"ema_decay": ema_embeddings.decay} |
| 137 | return {} | 137 | return {} |
| 138 | 138 | ||
