From 16b92605a59d59c65789c89b54bb97da51908056 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Feb 2023 09:09:50 +0100 Subject: Embedding normalization: Ignore tensors with grad = 0 --- training/functional.py | 7 +++++-- training/strategy/ti.py | 15 +++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 85dd884..739d055 100644 --- a/training/functional.py +++ b/training/functional.py @@ -362,6 +362,7 @@ def train_loop( loss_step: LossCallable, sample_frequency: int = 10, checkpoint_frequency: int = 50, + milestone_checkpoints: bool = True, global_step_offset: int = 0, num_epochs: int = 100, callbacks: TrainingCallbacks = TrainingCallbacks(), @@ -514,7 +515,7 @@ def train_loop( accelerator.log(logs, step=global_step) if accelerator.is_main_process: - if avg_acc_val.avg.item() > best_acc_val: + if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: local_progress_bar.clear() global_progress_bar.clear() @@ -527,7 +528,7 @@ def train_loop( accs.append(avg_acc_val.avg.item()) else: if accelerator.is_main_process: - if avg_acc.avg.item() > best_acc: + if avg_acc.avg.item() > best_acc and milestone_checkpoints: local_progress_bar.clear() global_progress_bar.clear() @@ -572,6 +573,7 @@ def train( num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, + milestone_checkpoints: bool = True, global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, @@ -626,6 +628,7 @@ def train( loss_step=loss_step_, sample_frequency=sample_frequency, checkpoint_frequency=checkpoint_frequency, + milestone_checkpoints=milestone_checkpoints, global_step_offset=global_step_offset, num_epochs=num_train_epochs, callbacks=callbacks, diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 66d3129..09beec4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -116,10 +116,17 @@ def 
textual_inversion_strategy_callbacks( @torch.no_grad() def on_before_optimize(lr: float, epoch: int): if use_emb_decay: - text_encoder.text_model.embeddings.normalize( - emb_decay_target, - min(1.0, emb_decay * lr) - ) + lambda_ = emb_decay * lr + + if lambda_ != 0: + w = text_encoder.text_model.embeddings.temp_token_embedding.weight + + mask = torch.zeros(w.size(0), dtype=torch.bool, device=w.device) + mask[text_encoder.text_model.embeddings.temp_token_ids] = True + mask[torch.all(w.grad == 0, dim=1)] = False + + norm = w[mask, :].norm(dim=-1, keepdim=True) + w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) def on_after_optimize(lr: float): if ema_embeddings is not None: -- cgit v1.2.3-70-g09d2