diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 4 | ||||
-rw-r--r-- | training/strategy/lora.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 23 |
3 files changed, 25 insertions, 6 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( | |||
148 | torch.cuda.empty_cache() | 148 | torch.cuda.empty_cache() |
149 | 149 | ||
150 | @torch.no_grad() | 150 | @torch.no_grad() |
151 | def on_sample(step): | 151 | def on_sample(cycle, step): |
152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
154 | 154 | ||
@@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( | |||
158 | unet_.to(dtype=weight_dtype) | 158 | unet_.to(dtype=weight_dtype) |
159 | text_encoder_.to(dtype=weight_dtype) | 159 | text_encoder_.to(dtype=weight_dtype) |
160 | 160 | ||
161 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 161 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
162 | 162 | ||
163 | unet_.to(dtype=orig_unet_dtype) | 163 | unet_.to(dtype=orig_unet_dtype) |
164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() |
147 | 147 | ||
148 | @torch.no_grad() | 148 | @torch.no_grad() |
149 | def on_sample(step): | 149 | def on_sample(cycle, step): |
150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
152 | 152 | ||
153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
154 | 154 | ||
155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ |
156 | 156 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f0b84b5..6bbff64 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,10 +104,28 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(epoch: int): | ||
108 | if use_emb_decay: | ||
109 | params = [ | ||
110 | p | ||
111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
112 | if p.grad is not None | ||
113 | ] | ||
114 | return torch.stack(params) if len(params) != 0 else None | ||
115 | |||
116 | @torch.no_grad() | ||
107 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
108 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
109 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
110 | 120 | ||
121 | if use_emb_decay and w is not None: | ||
122 | lr = lrs["emb"] or lrs["0"] | ||
123 | lambda_ = emb_decay * lr | ||
124 | |||
125 | if lambda_ != 0: | ||
126 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
128 | |||
111 | def on_log(): | 129 | def on_log(): |
112 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: |
113 | return {"ema_decay": ema_embeddings.decay} | 131 | return {"ema_decay": ema_embeddings.decay} |
@@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks( | |||
125 | ) | 143 | ) |
126 | 144 | ||
127 | @torch.no_grad() | 145 | @torch.no_grad() |
128 | def on_sample(step): | 146 | def on_sample(cycle, step): |
129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 147 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
131 | 149 | ||
@@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks( | |||
135 | unet_.to(dtype=weight_dtype) | 153 | unet_.to(dtype=weight_dtype) |
136 | text_encoder_.to(dtype=weight_dtype) | 154 | text_encoder_.to(dtype=weight_dtype) |
137 | 155 | ||
138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 156 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
139 | 157 | ||
140 | unet_.to(dtype=orig_unet_dtype) | 158 | unet_.to(dtype=orig_unet_dtype) |
141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 159 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
@@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks( | |||
148 | return TrainingCallbacks( | 166 | return TrainingCallbacks( |
149 | on_train=on_train, | 167 | on_train=on_train, |
150 | on_eval=on_eval, | 168 | on_eval=on_eval, |
169 | on_before_optimize=on_before_optimize, | ||
151 | on_after_optimize=on_after_optimize, | 170 | on_after_optimize=on_after_optimize, |
152 | on_log=on_log, | 171 | on_log=on_log, |
153 | on_checkpoint=on_checkpoint, | 172 | on_checkpoint=on_checkpoint, |