 train_ti.py    | 6 +++---
 training/lr.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e6c437e..951f3dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -868,11 +868,11 @@ def main():
         pass
 
     @torch.no_grad()
-    def on_clip():
+    def on_clip(lr):
         embeddings = text_encoder.text_model.embeddings.temp_token_embedding
 
         pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
-        lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
+        lambda_ = min(1.0, 100 * lr)
         embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
 
     loop = partial(
@@ -991,7 +991,7 @@ def main():
             accelerator.backward(loss)
 
             if accelerator.sync_gradients:
-                on_clip()
+                on_clip(lr_scheduler.get_last_lr()[0])
 
             optimizer.step()
             if not accelerator.optimizer_step_was_skipped:
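
The change above decouples on_clip from its enclosing scope: instead of reading lr_scheduler.get_last_lr()[0] inside the callback, the caller now passes the current learning rate in. A minimal standalone sketch of the resulting callback follows; the embedding argument and target_norm parameter are illustrative names not present in the diff, and only the interpolation logic mirrors the patch:

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def on_clip(embedding: torch.nn.Embedding, lr: float, target_norm: float = 0.4) -> None:
        # Nudge each embedding vector's norm toward target_norm. The
        # interpolation strength lambda_ scales with the learning rate
        # and is capped at 1.0 (a full snap to target_norm).
        pre_norm = embedding.weight.norm(dim=-1, keepdim=True)
        lambda_ = min(1.0, 100 * lr)
        embedding.weight[:] = F.normalize(embedding.weight, dim=-1) * (
            pre_norm + lambda_ * (target_norm - pre_norm)
        )
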
diff --git a/training/lr.py b/training/lr.py
index dfb1743..01f7f5e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[], None] = noop,
+        on_clip: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -95,7 +95,7 @@ class LRFinder():
                 self.accelerator.backward(loss)
 
                 if self.accelerator.sync_gradients:
-                    self.on_clip()
+                    self.on_clip(lr_scheduler.get_last_lr()[0])
 
                 self.optimizer.step()
                 lr_scheduler.step()
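
With the signature widened to Callable[[float], None], the same callback can be driven by LRFinder, which supplies the rate from its own scheduler during the LR sweep. A hedged wiring sketch is below; only accelerator, val_dataloader, loss_fn, on_train, on_clip, and on_eval are confirmed by the hunks above, so the remaining constructor arguments and the run() entry point are assumptions about the full API:

    # Sketch only: parameter names beyond those visible in the diff are assumed.
    lr_finder = LRFinder(
        accelerator=accelerator,
        model=text_encoder,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_step,
        on_clip=on_clip,  # now Callable[[float], None]; receives get_last_lr()[0]
    )
    lr_finder.run()  # hypothetical entry point; name not shown in the diff

Passing the learning rate as an argument rather than closing over a specific lr_scheduler is what lets the training loop and LRFinder share one callback, since each loop owns a different scheduler instance.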