diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/training/lr.py b/training/lr.py index ef01906..0c5ce9e 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -43,9 +43,6 @@ class LRFinder(): | |||
43 | ) | 43 | ) |
44 | progress_bar.set_description("Epoch X / Y") | 44 | progress_bar.set_description("Epoch X / Y") |
45 | 45 | ||
46 | train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches] | ||
47 | val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches] | ||
48 | |||
49 | for epoch in range(num_epochs): | 46 | for epoch in range(num_epochs): |
50 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 47 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
51 | 48 | ||
@@ -54,7 +51,10 @@ class LRFinder(): | |||
54 | 51 | ||
55 | self.model.train() | 52 | self.model.train() |
56 | 53 | ||
57 | for batch in train_workload: | 54 | for step, batch in enumerate(self.train_dataloader): |
55 | if step >= num_train_batches: | ||
56 | break | ||
57 | |||
58 | with self.accelerator.accumulate(self.model): | 58 | with self.accelerator.accumulate(self.model): |
59 | loss, acc, bsz = self.loss_fn(batch) | 59 | loss, acc, bsz = self.loss_fn(batch) |
60 | 60 | ||
@@ -69,7 +69,10 @@ class LRFinder(): | |||
69 | self.model.eval() | 69 | self.model.eval() |
70 | 70 | ||
71 | with torch.inference_mode(): | 71 | with torch.inference_mode(): |
72 | for batch in val_workload: | 72 | for step, batch in enumerate(self.val_dataloader): |
73 | if step >= num_val_batches: | ||
74 | break | ||
75 | |||
73 | loss, acc, bsz = self.loss_fn(batch) | 76 | loss, acc, bsz = self.loss_fn(batch) |
74 | avg_loss.update(loss.detach_(), bsz) | 77 | avg_loss.update(loss.detach_(), bsz) |
75 | avg_acc.update(acc.detach_(), bsz) | 78 | avg_acc.update(acc.detach_(), bsz) |