diff options
author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
commit | a72b6260c117cabe4fcb2996cce4f870986df99b (patch) | |
tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /training | |
parent | Fixed LR finder (diff) | |
download | textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2 textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip |
Added vector dropout
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py index acc01a2..37588b6 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -58,7 +58,11 @@ class LRFinder(): | |||
58 | losses = [] | 58 | losses = [] |
59 | accs = [] | 59 | accs = [] |
60 | 60 | ||
61 | lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs) | 61 | lr_scheduler = get_exponential_schedule( |
62 | self.optimizer, | ||
63 | end_lr, | ||
64 | num_epochs * min(num_train_batches, len(self.train_dataloader)) | ||
65 | ) | ||
62 | 66 | ||
63 | steps = min(num_train_batches, len(self.train_dataloader)) | 67 | steps = min(num_train_batches, len(self.train_dataloader)) |
64 | steps += min(num_val_batches, len(self.val_dataloader)) | 68 | steps += min(num_val_batches, len(self.val_dataloader)) |
@@ -90,6 +94,7 @@ class LRFinder(): | |||
90 | self.accelerator.backward(loss) | 94 | self.accelerator.backward(loss) |
91 | 95 | ||
92 | self.optimizer.step() | 96 | self.optimizer.step() |
97 | lr_scheduler.step() | ||
93 | self.optimizer.zero_grad(set_to_none=True) | 98 | self.optimizer.zero_grad(set_to_none=True) |
94 | 99 | ||
95 | if self.accelerator.sync_gradients: | 100 | if self.accelerator.sync_gradients: |
@@ -109,8 +114,6 @@ class LRFinder(): | |||
109 | 114 | ||
110 | progress_bar.update(1) | 115 | progress_bar.update(1) |
111 | 116 | ||
112 | lr_scheduler.step() | ||
113 | |||
114 | loss = avg_loss.avg.item() | 117 | loss = avg_loss.avg.item() |
115 | acc = avg_acc.avg.item() | 118 | acc = avg_acc.avg.item() |
116 | 119 | ||