From 8e9d62225db11913bf7ef67221fc3508d7fe1149 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 16:39:33 +0100 Subject: Update --- training/strategy/dreambooth.py | 5 ++--- training/strategy/ti.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) (limited to 'training/strategy') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index d813b49..f57e736 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks( def on_prepare(): unet.requires_grad_(True) text_encoder.requires_grad_(True) - text_encoder.text_model.embeddings.persist() - text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) + text_encoder.text_model.embeddings.requires_grad_(False) if ema_unet is not None: ema_unet.to(accelerator.device) @@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(lr: float, epoch: int): if accelerator.sync_gradients: params_to_clip = [unet.parameters()] if epoch < train_text_encoder_epochs: diff --git a/training/strategy/ti.py b/training/strategy/ti.py index ba78b98..e922954 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield - def on_after_optimize(lr: float): + @torch.no_grad() + def on_before_optimize(lr: float, epoch: int): if use_emb_decay: - with torch.no_grad(): - text_encoder.text_model.embeddings.normalize( - emb_decay_target, - min(1.0, emb_decay * lr) - ) + text_encoder.text_model.embeddings.normalize( + emb_decay_target, + min(1.0, emb_decay * lr) + ) + def on_after_optimize(lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) @@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks( on_model=on_model, on_train=on_train, on_eval=on_eval, + on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2