diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index bc26ee6..d813b49 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -88,7 +88,7 @@ def dreambooth_strategy_callbacks( | |||
| 88 | ema_unet = None | 88 | ema_unet = None |
| 89 | 89 | ||
| 90 | def ema_context(): | 90 | def ema_context(): |
| 91 | if use_ema: | 91 | if ema_unet is not None: |
| 92 | return ema_unet.apply_temporary(unet.parameters()) | 92 | return ema_unet.apply_temporary(unet.parameters()) |
| 93 | else: | 93 | else: |
| 94 | return nullcontext() | 94 | return nullcontext() |
| @@ -102,7 +102,7 @@ def dreambooth_strategy_callbacks( | |||
| 102 | text_encoder.text_model.embeddings.persist() | 102 | text_encoder.text_model.embeddings.persist() |
| 103 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | 103 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) |
| 104 | 104 | ||
| 105 | if use_ema: | 105 | if ema_unet is not None: |
| 106 | ema_unet.to(accelerator.device) | 106 | ema_unet.to(accelerator.device) |
| 107 | 107 | ||
| 108 | @contextmanager | 108 | @contextmanager |
| @@ -134,11 +134,11 @@ def dreambooth_strategy_callbacks( | |||
| 134 | 134 | ||
| 135 | @torch.no_grad() | 135 | @torch.no_grad() |
| 136 | def on_after_optimize(lr: float): | 136 | def on_after_optimize(lr: float): |
| 137 | if use_ema: | 137 | if ema_unet is not None: |
| 138 | ema_unet.step(unet.parameters()) | 138 | ema_unet.step(unet.parameters()) |
| 139 | 139 | ||
| 140 | def on_log(): | 140 | def on_log(): |
| 141 | if use_ema: | 141 | if ema_unet is not None: |
| 142 | return {"ema_decay": ema_unet.decay} | 142 | return {"ema_decay": ema_unet.decay} |
| 143 | return {} | 143 | return {} |
| 144 | 144 | ||
