summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/lr.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/training/lr.py b/training/lr.py
index ef01906..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,9 +43,6 @@ class LRFinder():
43 ) 43 )
44 progress_bar.set_description("Epoch X / Y") 44 progress_bar.set_description("Epoch X / Y")
45 45
46 train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
47 val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
48
49 for epoch in range(num_epochs): 46 for epoch in range(num_epochs):
50 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 47 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
51 48
@@ -54,7 +51,10 @@ class LRFinder():
54 51
55 self.model.train() 52 self.model.train()
56 53
57 for batch in train_workload: 54 for step, batch in enumerate(self.train_dataloader):
55 if step >= num_train_batches:
56 break
57
58 with self.accelerator.accumulate(self.model): 58 with self.accelerator.accumulate(self.model):
59 loss, acc, bsz = self.loss_fn(batch) 59 loss, acc, bsz = self.loss_fn(batch)
60 60
@@ -69,7 +69,10 @@ class LRFinder():
69 self.model.eval() 69 self.model.eval()
70 70
71 with torch.inference_mode(): 71 with torch.inference_mode():
72 for batch in val_workload: 72 for step, batch in enumerate(self.val_dataloader):
73 if step >= num_val_batches:
74 break
75
73 loss, acc, bsz = self.loss_fn(batch) 76 loss, acc, bsz = self.loss_fn(batch)
74 avg_loss.update(loss.detach_(), bsz) 77 avg_loss.update(loss.detach_(), bsz)
75 avg_acc.update(acc.detach_(), bsz) 78 avg_acc.update(acc.detach_(), bsz)