diff options
Diffstat (limited to 'training/lr.py')
-rw-r--r-- | training/lr.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/training/lr.py b/training/lr.py index 0c5ce9e..ef01906 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -43,6 +43,9 @@ class LRFinder(): | |||
43 | ) | 43 | ) |
44 | progress_bar.set_description("Epoch X / Y") | 44 | progress_bar.set_description("Epoch X / Y") |
45 | 45 | ||
46 | train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches] | ||
47 | val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches] | ||
48 | |||
46 | for epoch in range(num_epochs): | 49 | for epoch in range(num_epochs): |
47 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 50 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
48 | 51 | ||
@@ -51,10 +54,7 @@ class LRFinder(): | |||
51 | 54 | ||
52 | self.model.train() | 55 | self.model.train() |
53 | 56 | ||
54 | for step, batch in enumerate(self.train_dataloader): | 57 | for batch in train_workload: |
55 | if step >= num_train_batches: | ||
56 | break | ||
57 | |||
58 | with self.accelerator.accumulate(self.model): | 58 | with self.accelerator.accumulate(self.model): |
59 | loss, acc, bsz = self.loss_fn(batch) | 59 | loss, acc, bsz = self.loss_fn(batch) |
60 | 60 | ||
@@ -69,10 +69,7 @@ class LRFinder(): | |||
69 | self.model.eval() | 69 | self.model.eval() |
70 | 70 | ||
71 | with torch.inference_mode(): | 71 | with torch.inference_mode(): |
72 | for step, batch in enumerate(self.val_dataloader): | 72 | for batch in val_workload: |
73 | if step >= num_val_batches: | ||
74 | break | ||
75 | |||
76 | loss, acc, bsz = self.loss_fn(batch) | 73 | loss, acc, bsz = self.loss_fn(batch) |
77 | avg_loss.update(loss.detach_(), bsz) | 74 | avg_loss.update(loss.detach_(), bsz) |
78 | avg_acc.update(acc.detach_(), bsz) | 75 | avg_acc.update(acc.detach_(), bsz) |