From 96638bbd54ca7f91d44c938fae7275d3ecaa6add Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Feb 2023 14:08:49 +0100 Subject: Fixed TI normalization order --- training/functional.py | 8 ++++---- training/strategy/dreambooth.py | 11 +++++------ training/strategy/ti.py | 15 ++++++++++----- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/training/functional.py b/training/functional.py index e7c4320..b830261 100644 --- a/training/functional.py +++ b/training/functional.py @@ -38,8 +38,8 @@ class TrainingCallbacks(): on_accum_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) - on_before_optimize: Callable[[float, int], None] = const() - on_after_optimize: Callable[[float], None] = const() + on_before_optimize: Callable[[float, int], Any] = const() + on_after_optimize: Callable[[Any, float], None] = const() on_after_epoch: Callable[[float], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) on_sample: Callable[[int], None] = const() @@ -455,13 +455,13 @@ def train_loop( local_progress_bar.set_postfix(**logs) if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): - on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) + before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) - on_after_optimize(lr_scheduler.get_last_lr()[0]) + on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) local_progress_bar.update(1) global_progress_bar.update(1) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index fcf5c0d..0290327 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -115,14 +115,13 @@ def dreambooth_strategy_callbacks( yield def on_before_optimize(lr: float, epoch: int): - if accelerator.sync_gradients: - params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: - params_to_clip.append(text_encoder.parameters()) - accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) + params_to_clip = [unet.parameters()] + if epoch < train_text_encoder_epochs: + params_to_clip.append(text_encoder.parameters()) + accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @torch.no_grad() - def on_after_optimize(lr: float): + def on_after_optimize(_, lr: float): if ema_unet is not None: ema_unet.step(unet.parameters()) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 09beec4..732cd74 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -115,6 +115,15 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_before_optimize(lr: float, epoch: int): + if use_emb_decay: + w = text_encoder.text_model.embeddings.temp_token_embedding.weight + return torch.all(w.grad == 0, dim=1) + + @torch.no_grad() + def on_after_optimize(zero_ids, lr: float): + if ema_embeddings is not None: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + if use_emb_decay: lambda_ = emb_decay * lr @@ -123,15 +132,11 @@ def textual_inversion_strategy_callbacks( mask = torch.zeros(w.size(0), dtype=torch.bool) mask[text_encoder.text_model.embeddings.temp_token_ids] = True - mask[torch.all(w.grad == 0, dim=1)] = False + mask[zero_ids] = False norm = w[mask, :].norm(dim=-1, keepdim=True) w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) - def on_after_optimize(lr: float): - if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} -- cgit v1.2.3-54-g00ecf