From b89953bea7dfe6c92164888a66d05bc7d987ef71 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 11 Jan 2023 22:29:28 +0100
Subject: Fix

---
 train_ti.py    | 6 +++---
 training/lr.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index e6c437e..951f3dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -868,11 +868,11 @@ def main():
             pass
 
     @torch.no_grad()
-    def on_clip():
+    def on_clip(lr):
         embeddings = text_encoder.text_model.embeddings.temp_token_embedding
 
         pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
-        lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
+        lambda_ = min(1.0, 100 * lr)
         embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
 
     loop = partial(
@@ -991,7 +991,7 @@ def main():
                     accelerator.backward(loss)
 
                     if accelerator.sync_gradients:
-                        on_clip()
+                        on_clip(lr_scheduler.get_last_lr()[0])
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
diff --git a/training/lr.py b/training/lr.py
index dfb1743..01f7f5e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[], None] = noop,
+        on_clip: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -95,7 +95,7 @@ class LRFinder():
                 self.accelerator.backward(loss)
 
                 if self.accelerator.sync_gradients:
-                    self.on_clip()
+                    self.on_clip(lr_scheduler.get_last_lr()[0])
 
                 self.optimizer.step()
                 lr_scheduler.step()
-- 
cgit v1.2.3-70-g09d2
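
Note on the change: the embedding norm-clipping callback no longer reads the learning rate from the training loop's lr_scheduler; it now receives the rate as an argument, and the LRFinder type hint and call site are updated to match, so the same callback also works when LRFinder drives its own scheduler. Below is a minimal sketch of the resulting callback, using a hypothetical standalone embedding_weight tensor and a target_norm parameter in place of text_encoder.text_model.embeddings.temp_token_embedding (both names are illustrative, not part of the patch):

    import torch
    import torch.nn.functional as F

    # Sketch only: `embedding_weight` and `target_norm` are stand-ins; the patch
    # operates on text_encoder.text_model.embeddings.temp_token_embedding.weight
    # with a fixed target norm of 0.4.
    @torch.no_grad()
    def on_clip(embedding_weight: torch.Tensor, lr: float, target_norm: float = 0.4):
        pre_norm = embedding_weight.norm(dim=-1, keepdim=True)
        # Blend each embedding's norm toward target_norm; the blend strength
        # scales with the current learning rate and is capped at 1.0.
        lambda_ = min(1.0, 100 * lr)
        embedding_weight[:] = F.normalize(embedding_weight, dim=-1) * (
            pre_norm + lambda_ * (target_norm - pre_norm)
        )

The caller now supplies whichever scheduler's current rate applies, e.g. on_clip(embeddings.weight, lr_scheduler.get_last_lr()[0]) in both the training loop and LRFinder.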