summary | refs | log | tree | commit | diff | stats
path: root/training/lr.py
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-01-04 10:32:58 +0100
committer Volpeon <git@volpeon.ink> 2023-01-04 10:32:58 +0100
commit bed44095ab99440467c2f302899b970c92baebf8 (patch)
tree 2b469fe74e0dc22f0fa38413c69135952363f2af /training/lr.py
parent Fixed reproducibility, more consistant validation (diff)
download textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.gz
textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.tar.bz2
textual-inversion-diff-bed44095ab99440467c2f302899b970c92baebf8.zip
Better eval generator
Diffstat (limited to 'training/lr.py')
-rw-r--r-- training/lr.py 6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
24 optimizer, 24 optimizer,
25 train_dataloader, 25 train_dataloader,
26 val_dataloader, 26 val_dataloader,
27 loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], 27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], None] = noop, 28 on_train: Callable[[], None] = noop,
29 on_eval: Callable[[], None] = noop 29 on_eval: Callable[[], None] = noop
30 ): 30 ):
@@ -89,7 +89,7 @@ class LRFinder():
89 break 89 break
90 90
91 with self.accelerator.accumulate(self.model): 91 with self.accelerator.accumulate(self.model):
92 loss, acc, bsz = self.loss_fn(batch) 92 loss, acc, bsz = self.loss_fn(step, batch)
93 93
94 self.accelerator.backward(loss) 94 self.accelerator.backward(loss)
95 95
@@ -108,7 +108,7 @@ class LRFinder():
108 if step >= num_val_batches: 108 if step >= num_val_batches:
109 break 109 break
110 110
111 loss, acc, bsz = self.loss_fn(batch, True) 111 loss, acc, bsz = self.loss_fn(step, batch, True)
112 avg_loss.update(loss.detach_(), bsz) 112 avg_loss.update(loss.detach_(), bsz)
113 avg_acc.update(acc.detach_(), bsz) 113 avg_acc.update(acc.detach_(), bsz)
114 114