From 950f1f6bcbb1a767170cea590b828d8e3cdae882 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 23 Jun 2023 06:48:38 +0200
Subject: Update

---
 training/strategy/dreambooth.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

(limited to 'training/strategy')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index dc19ba3..0f64747 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -30,8 +30,13 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
     train_text_encoder_cycles: int,
     text_encoder_unfreeze_last_n_layers: int = 2,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -112,11 +117,29 @@ def dreambooth_strategy_callbacks(
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
+        if len(placeholder_tokens) != 0 and use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
-    def on_after_optimize(_, lrs: dict[str, float]):
+    def on_after_optimize(w, lrs: dict[str, float]):
         if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
+        if w is not None and "emb" in lrs:
+            lr = lrs["emb"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
+
     def on_log():
         if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
@@ -212,6 +235,7 @@ def dreambooth_prepare(
     ]:
         layer.requires_grad_(False)
 
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     # text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
--
cgit v1.2.3-70-g09d2
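
For reference, a minimal standalone sketch of the embedding-decay step this patch wires into on_after_optimize: after each optimizer step, every trainable embedding vector is nudged toward a target L2 norm, with strength scaled by the "emb" learning rate. The function name, dummy shapes, and usage lines below are illustrative assumptions, not part of the patch; only the update rule (lambda_ = emb_decay * lr, then w += (w / norm) * lambda_ * (emb_decay_target - norm)) comes from the diff above.

import torch

@torch.no_grad()
def decay_embeddings(
    w: torch.Tensor,                 # trainable embedding weights, shape (..., dim)
    lr: float,                       # current learning rate of the "emb" param group
    emb_decay: float = 1e-2,         # defaults mirror the new keyword arguments
    emb_decay_target: float = 0.4,
) -> None:
    """Pull each embedding vector's L2 norm toward emb_decay_target, in place."""
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        # Per-vector L2 norms; the clamp avoids division by zero for null vectors.
        norm = w.norm(dim=-1, keepdim=True)
        # Move each vector along its own direction: vectors longer than the
        # target shrink, shorter ones grow, proportionally to the gap.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

# Illustrative usage with dummy data (hypothetical shapes):
w = torch.stack([torch.randn(2, 768)])   # stands in for the stacked embedding params
decay_embeddings(w, lr=1e-4)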