From dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 30 Dec 2022 13:48:26 +0100
Subject: Training script improvements

---
 training/lr.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

(limited to 'training/lr.py')

diff --git a/training/lr.py b/training/lr.py
index ef01906..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,9 +43,6 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
-        train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
-        val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
-
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -54,7 +51,10 @@ class LRFinder():
 
             self.model.train()
 
-            for batch in train_workload:
+            for step, batch in enumerate(self.train_dataloader):
+                if step >= num_train_batches:
+                    break
+
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
 
@@ -69,7 +69,10 @@ class LRFinder():
             self.model.eval()
 
             with torch.inference_mode():
-                for batch in val_workload:
+                for step, batch in enumerate(self.val_dataloader):
+                    if step >= num_val_batches:
+                        break
+
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
-- 
cgit v1.2.3-54-g00ecf
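
The patch above swaps pre-collected batch lists for lazy iteration with an early break, so only the batches actually used are drawn from the dataloaders and nothing is materialized up front. A minimal standalone sketch of that pattern, using a hypothetical TensorDataset and an arbitrary batch budget rather than the LRFinder code itself:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_dataloader = DataLoader(dataset, batch_size=4)
num_train_batches = 3

# Old approach (the removed lines): the comprehension still walks the entire
# dataloader and keeps the selected batches in memory for the whole run.
train_workload = [batch for i, batch in enumerate(train_dataloader) if i < num_train_batches]

# New approach (the added lines): iterate lazily and stop once the budget is reached.
for step, batch in enumerate(train_dataloader):
    if step >= num_train_batches:
        break
    inputs, targets = batch
    # a real training step (forward, loss, backward, optimizer step) would go here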