path: root/training/lr.py
author     Volpeon <git@volpeon.ink>    2023-01-13 18:59:26 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-13 18:59:26 +0100
commit     127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree       61cb98adbf33ed08506601f8b70f1b62bc42c4ee /training/lr.py
parent     Simplified step calculations (diff)
More modularization
Diffstat (limited to 'training/lr.py')
-rw-r--r--  training/lr.py | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 84e30a0..7584ba2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -16,6 +16,10 @@ def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 class LRFinder():
     def __init__(
         self,
@@ -25,10 +29,10 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_before_optimize: Callable[[], None] = noop,
+        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+        on_before_optimize: Callable[[int], None] = noop,
         on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
     ):
         self.accelerator = accelerator
         self.model = model
@@ -86,7 +90,7 @@ class LRFinder():
 
         self.model.train()
 
-        with self.on_train():
+        with self.on_train(epoch):
             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
                     break
@@ -96,7 +100,7 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                self.on_before_optimize()
+                self.on_before_optimize(epoch)
 
                 self.optimizer.step()
                 lr_scheduler.step()
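
The commit threads the current epoch into the on_train and on_before_optimize hooks. A minimal sketch of why the no-op defaults change along with the signatures, assuming the intent is that the default hooks silently ignore the new argument: contextlib.nullcontext accepts only a single optional enter_result parameter, so keeping it as the default would repurpose that parameter once the epoch is passed in, whereas noop_ctx discards any arguments outright.

    from contextlib import nullcontext

    def noop(*args, **kwargs):
        pass

    def noop_ctx(*args, **kwargs):
        # Ignore whatever the caller passes (here, the epoch) and
        # return a fresh do-nothing context manager.
        return nullcontext()

    epoch = 3  # arbitrary value for illustration only

    # The new call sites from the diff, run against the defaults:
    with noop_ctx(epoch):   # on_train(epoch): argument is swallowed
        pass
    noop(epoch)             # on_before_optimize(epoch): likewise a no-op

    # The old default would still run, but the epoch would silently
    # become the context's enter_result instead of being ignored:
    with nullcontext(epoch):
        pass

The sketch is not part of the repository's code; it only illustrates the behavior of the two defaults under the new one-argument call sites.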