From f87d9fdf541b0282249ddde1dc0302317350f998 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 29 Dec 2022 15:28:02 +0100 Subject: Update --- training/lr.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'training/lr.py') diff --git a/training/lr.py b/training/lr.py index 0c5ce9e..ef01906 100644 --- a/training/lr.py +++ b/training/lr.py @@ -43,6 +43,9 @@ class LRFinder(): ) progress_bar.set_description("Epoch X / Y") + train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches] + val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches] + for epoch in range(num_epochs): progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -51,10 +54,7 @@ class LRFinder(): self.model.train() - for step, batch in enumerate(self.train_dataloader): - if step >= num_train_batches: - break - + for batch in train_workload: with self.accelerator.accumulate(self.model): loss, acc, bsz = self.loss_fn(batch) @@ -69,10 +69,7 @@ class LRFinder(): self.model.eval() with torch.inference_mode(): - for step, batch in enumerate(self.val_dataloader): - if step >= num_val_batches: - break - + for batch in val_workload: loss, acc, bsz = self.loss_fn(batch) avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) -- cgit v1.2.3-54-g00ecf