 train_ti.py    | 6 +++---
 training/lr.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e6c437e..951f3dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -868,11 +868,11 @@ def main():
             pass
 
     @torch.no_grad()
-    def on_clip():
+    def on_clip(lr):
         embeddings = text_encoder.text_model.embeddings.temp_token_embedding
 
         pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
-        lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
+        lambda_ = min(1.0, 100 * lr)
         embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
 
     loop = partial(
@@ -991,7 +991,7 @@ def main():
                     accelerator.backward(loss)
 
                     if accelerator.sync_gradients:
-                        on_clip()
+                        on_clip(lr_scheduler.get_last_lr()[0])
 
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
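The rescaling that on_clip applies is easier to see in isolation: each row of the embedding matrix keeps its direction, but its L2 norm is pulled a fraction min(1.0, 100 * lr) of the way toward a target of 0.4, so higher learning rates decay the norms faster, and any lr >= 0.01 snaps them straight to the target. Below is a minimal standalone sketch of that step; the shapes, the target argument, and the toy usage are illustrative, not from the repository (in train_ti.py the tensor is temp_token_embedding.weight).

import torch
import torch.nn.functional as F

@torch.no_grad()
def clip_embedding_norms(weight: torch.Tensor, lr: float, target: float = 0.4) -> None:
    # Per-row L2 norms, shape (num_embeddings, 1).
    pre_norm = weight.norm(dim=-1, keepdim=True)
    # Interpolation factor: saturates at 1.0 once lr >= 0.01,
    # at which point the norms snap exactly to the target.
    lambda_ = min(1.0, 100 * lr)
    # Unit-normalize each row, then rescale to the interpolated norm.
    weight[:] = F.normalize(weight, dim=-1) * (pre_norm + lambda_ * (target - pre_norm))

emb = torch.randn(8, 768)
clip_embedding_norms(emb, lr=1e-3)  # norms move 10% of the way toward 0.4
print(emb.norm(dim=-1))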
diff --git a/training/lr.py b/training/lr.py
index dfb1743..01f7f5e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[], None] = noop,
+        on_clip: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -95,7 +95,7 @@ class LRFinder():
                 self.accelerator.backward(loss)
 
                 if self.accelerator.sync_gradients:
-                    self.on_clip()
+                    self.on_clip(lr_scheduler.get_last_lr()[0])
 
                 self.optimizer.step()
                 lr_scheduler.step()
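The signature change is the point of the commit: LRFinder drives a scheduler of its own (note the bare lr_scheduler next to self.optimizer in the hunk above), so the old on_clip in train_ti.py, which closed over the training script's lr_scheduler, would keep reading that scheduler rather than the finder's sweep when invoked from here. With the learning rate passed as an argument, both loops can share one hook. A sketch of the updated contract follows; the OnClip alias and apply_clip helper are illustrative names, not from the repository.

from typing import Callable

OnClip = Callable[[float], None]

def noop(lr: float) -> None:
    # Default hook, matching the new Callable[[float], None] signature.
    pass

def apply_clip(on_clip: OnClip, last_lr: float, sync_gradients: bool) -> None:
    # Both loops now invoke the hook identically, each passing its own
    # scheduler's get_last_lr()[0] for the step in progress.
    if sync_gradients:
        on_clip(last_lr)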