summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-29 15:28:02 +0100
committerVolpeon <git@volpeon.ink>2022-12-29 15:28:02 +0100
commitf87d9fdf541b0282249ddde1dc0302317350f998 (patch)
treea27f4319d90098f026784711ffd1a415fa561def /training/lr.py
parentTraining improvements (diff)
downloadtextual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.tar.gz
textual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.tar.bz2
textual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.zip
Update
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..ef01906 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,6 +43,9 @@ class LRFinder():
43 ) 43 )
44 progress_bar.set_description("Epoch X / Y") 44 progress_bar.set_description("Epoch X / Y")
45 45
46 train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
47 val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
48
46 for epoch in range(num_epochs): 49 for epoch in range(num_epochs):
47 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 50 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
48 51
@@ -51,10 +54,7 @@ class LRFinder():
51 54
52 self.model.train() 55 self.model.train()
53 56
54 for step, batch in enumerate(self.train_dataloader): 57 for batch in train_workload:
55 if step >= num_train_batches:
56 break
57
58 with self.accelerator.accumulate(self.model): 58 with self.accelerator.accumulate(self.model):
59 loss, acc, bsz = self.loss_fn(batch) 59 loss, acc, bsz = self.loss_fn(batch)
60 60
@@ -69,10 +69,7 @@ class LRFinder():
69 self.model.eval() 69 self.model.eval()
70 70
71 with torch.inference_mode(): 71 with torch.inference_mode():
72 for step, batch in enumerate(self.val_dataloader): 72 for batch in val_workload:
73 if step >= num_val_batches:
74 break
75
76 loss, acc, bsz = self.loss_fn(batch) 73 loss, acc, bsz = self.loss_fn(batch)
77 avg_loss.update(loss.detach_(), bsz) 74 avg_loss.update(loss.detach_(), bsz)
78 avg_acc.update(acc.detach_(), bsz) 75 avg_acc.update(acc.detach_(), bsz)