diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
commit | 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch) | |
tree | 61cb98adbf33ed08506601f8b70f1b62bc42c4ee /training/lr.py | |
parent | Simplified step calculations (diff) | |
download | textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2 textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip |
More modularization
Diffstat (limited to 'training/lr.py')
-rw-r--r-- | training/lr.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/training/lr.py b/training/lr.py index 84e30a0..7584ba2 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -16,6 +16,10 @@ def noop(*args, **kwards): | |||
16 | pass | 16 | pass |
17 | 17 | ||
18 | 18 | ||
19 | def noop_ctx(*args, **kwards): | ||
20 | return nullcontext() | ||
21 | |||
22 | |||
19 | class LRFinder(): | 23 | class LRFinder(): |
20 | def __init__( | 24 | def __init__( |
21 | self, | 25 | self, |
@@ -25,10 +29,10 @@ class LRFinder(): | |||
25 | train_dataloader, | 29 | train_dataloader, |
26 | val_dataloader, | 30 | val_dataloader, |
27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 31 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
28 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, | 32 | on_train: Callable[[int], _GeneratorContextManager] = noop_ctx, |
29 | on_before_optimize: Callable[[], None] = noop, | 33 | on_before_optimize: Callable[[int], None] = noop, |
30 | on_after_optimize: Callable[[float], None] = noop, | 34 | on_after_optimize: Callable[[float], None] = noop, |
31 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext | 35 | on_eval: Callable[[], _GeneratorContextManager] = noop_ctx |
32 | ): | 36 | ): |
33 | self.accelerator = accelerator | 37 | self.accelerator = accelerator |
34 | self.model = model | 38 | self.model = model |
@@ -86,7 +90,7 @@ class LRFinder(): | |||
86 | 90 | ||
87 | self.model.train() | 91 | self.model.train() |
88 | 92 | ||
89 | with self.on_train(): | 93 | with self.on_train(epoch): |
90 | for step, batch in enumerate(self.train_dataloader): | 94 | for step, batch in enumerate(self.train_dataloader): |
91 | if step >= num_train_batches: | 95 | if step >= num_train_batches: |
92 | break | 96 | break |
@@ -96,7 +100,7 @@ class LRFinder(): | |||
96 | 100 | ||
97 | self.accelerator.backward(loss) | 101 | self.accelerator.backward(loss) |
98 | 102 | ||
99 | self.on_before_optimize() | 103 | self.on_before_optimize(epoch) |
100 | 104 | ||
101 | self.optimizer.step() | 105 | self.optimizer.step() |
102 | lr_scheduler.step() | 106 | lr_scheduler.step() |