Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--  training/strategy/dreambooth.py  26
1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)
 
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
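
For reference, the embedding decay added in on_after_optimize nudges the L2 norm of each trainable token embedding toward emb_decay_target, scaled by emb_decay and the learning rate of the "emb" parameter group. A minimal standalone sketch of that update follows, using a dummy weight matrix and made-up hyperparameter values rather than the strategy's actual wiring:

import torch

# Hypothetical stand-ins for the values the strategy receives.
emb_decay = 1e-2          # decay strength (hyperparameter)
emb_decay_target = 0.4    # target L2 norm per embedding vector
lr = 1e-3                 # learning rate of the "emb" param group

# Dummy stack of trainable placeholder-token embeddings (4 tokens, 768 dims).
w = torch.randn(4, 768)

lambda_ = emb_decay * lr
if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    # Move each row's norm a small step toward emb_decay_target,
    # along the direction of the embedding itself.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(w.norm(dim=-1))  # norms shift slightly toward emb_decay_target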