diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 6 |
1 file changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index e6c437e..951f3dd 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -868,11 +868,11 @@ def main(): | |||
| 868 | pass | 868 | pass |
| 869 | 869 | ||
| 870 | @torch.no_grad() | 870 | @torch.no_grad() |
| 871 | def on_clip(): | 871 | def on_clip(lr): |
| 872 | embeddings = text_encoder.text_model.embeddings.temp_token_embedding | 872 | embeddings = text_encoder.text_model.embeddings.temp_token_embedding |
| 873 | 873 | ||
| 874 | pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) | 874 | pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) |
| 875 | lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) | 875 | lambda_ = min(1.0, 100 * lr) |
| 876 | embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) | 876 | embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) |
| 877 | 877 | ||
| 878 | loop = partial( | 878 | loop = partial( |
| @@ -991,7 +991,7 @@ def main(): | |||
| 991 | accelerator.backward(loss) | 991 | accelerator.backward(loss) |
| 992 | 992 | ||
| 993 | if accelerator.sync_gradients: | 993 | if accelerator.sync_gradients: |
| 994 | on_clip() | 994 | on_clip(lr_scheduler.get_last_lr()[0]) |
| 995 | 995 | ||
| 996 | optimizer.step() | 996 | optimizer.step() |
| 997 | if not accelerator.optimizer_step_was_skipped: | 997 | if not accelerator.optimizer_step_was_skipped: |
