From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- training/strategy/ti.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'training/strategy/ti.py') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if use_emb_decay: params = [ p @@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) + ema_embeddings.step( + text_encoder.text_model.embeddings.token_embedding.parameters() + ) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) def on_log(): if ema_embeddings is not None: @@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks( print(f"Saving checkpoint for step {step}...") with ema_context(): - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", ) @torch.no_grad() @@ -183,7 +187,7 @@ def textual_inversion_prepare( val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, gradient_checkpointing: bool = False, - **kwargs + **kwargs, ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -191,8 +195,15 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) unet.to(accelerator.device, dtype=weight_dtype) unet.requires_grad_(False) -- cgit v1.2.3-54-g00ecf