Diffstat (limited to 'training')

 training/strategy/dreambooth.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
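For context on the hunk above: on_before_optimize now returns the stacked trainable embedding rows, and on_after_optimize nudges each row's L2 norm toward emb_decay_target, scaled by the embedding learning rate. Below is a minimal standalone sketch of that update step; the helper name emb_norm_decay and the example tensor are illustrative only, not part of the repository.

import torch

def emb_norm_decay(w: torch.Tensor, lr: float,
                   emb_decay: float = 1e-2,
                   emb_decay_target: float = 0.4) -> None:
    # `w` is assumed to be the (num_rows, dim) tensor of trainable embedding
    # rows handed over by on_before_optimize; this mirrors the update applied
    # in on_after_optimize in the diff above.
    lambda_ = emb_decay * lr
    if lambda_ == 0:
        return
    with torch.no_grad():
        norm = w.norm(dim=-1, keepdim=True)  # per-row L2 norm, shape (num_rows, 1)
        # Move each row along its own direction by lambda_ * (target - norm):
        # rows longer than the target shrink, shorter rows grow.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

# Illustrative usage with random embedding rows:
w = torch.randn(2, 768)
emb_norm_decay(w, lr=1e-3)

Because lambda_ is the product of emb_decay and lrs["emb"], the pull toward the target norm anneals together with the learning-rate schedule, and for any row whose norm already equals emb_decay_target the update is zero.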
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)
 
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
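The dreambooth_prepare change freezes the position embedding while leaving the token embedding trainable, so gradients (and the decay above) only touch token rows such as the placeholder tokens. A small sketch of the effect, assuming a transformers CLIPTextModel; the checkpoint name is only an example:

from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Freeze positional embeddings; token embeddings stay trainable so the
# placeholder rows can still receive gradients.
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

trainable = [n for n, p in text_encoder.text_model.embeddings.named_parameters()
             if p.requires_grad]
print(trainable)  # expected: ['token_embedding.weight']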