summary refs log tree commit diff stats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- training/strategy/dreambooth.py | 8
1 file changed, 4 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bc26ee6..d813b49 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,7 +88,7 @@ def dreambooth_strategy_callbacks(
88 ema_unet = None 88 ema_unet = None
89 89
90 def ema_context(): 90 def ema_context():
91 if use_ema: 91 if ema_unet is not None:
92 return ema_unet.apply_temporary(unet.parameters()) 92 return ema_unet.apply_temporary(unet.parameters())
93 else: 93 else:
94 return nullcontext() 94 return nullcontext()
@@ -102,7 +102,7 @@ def dreambooth_strategy_callbacks(
102 text_encoder.text_model.embeddings.persist() 102 text_encoder.text_model.embeddings.persist()
103 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) 103 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
104 104
105 if use_ema: 105 if ema_unet is not None:
106 ema_unet.to(accelerator.device) 106 ema_unet.to(accelerator.device)
107 107
108 @contextmanager 108 @contextmanager
@@ -134,11 +134,11 @@ def dreambooth_strategy_callbacks(
134 134
135 @torch.no_grad() 135 @torch.no_grad()
136 def on_after_optimize(lr: float): 136 def on_after_optimize(lr: float):
137 if use_ema: 137 if ema_unet is not None:
138 ema_unet.step(unet.parameters()) 138 ema_unet.step(unet.parameters())
139 139
140 def on_log(): 140 def on_log():
141 if use_ema: 141 if ema_unet is not None:
142 return {"ema_decay": ema_unet.decay} 142 return {"ema_decay": ema_unet.decay}
143 return {} 143 return {}
144 144