summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-11 22:29:28 +0100
committerVolpeon <git@volpeon.ink>2023-01-11 22:29:28 +0100
commitb89953bea7dfe6c92164888a66d05bc7d987ef71 (patch)
tree7cfb060de5cb981373572bc0c8dfd7152b9e9173 /training/lr.py
parentHeck (diff)
downloadtextual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.tar.gz
textual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.tar.bz2
textual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.zip
Fix
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/training/lr.py b/training/lr.py
index dfb1743..01f7f5e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,7 @@ class LRFinder():
26 val_dataloader, 26 val_dataloader,
27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], _GeneratorContextManager] = nullcontext, 28 on_train: Callable[[], _GeneratorContextManager] = nullcontext,
29 on_clip: Callable[[], None] = noop, 29 on_clip: Callable[[float], None] = noop,
30 on_eval: Callable[[], _GeneratorContextManager] = nullcontext 30 on_eval: Callable[[], _GeneratorContextManager] = nullcontext
31 ): 31 ):
32 self.accelerator = accelerator 32 self.accelerator = accelerator
@@ -95,7 +95,7 @@ class LRFinder():
95 self.accelerator.backward(loss) 95 self.accelerator.backward(loss)
96 96
97 if self.accelerator.sync_gradients: 97 if self.accelerator.sync_gradients:
98 self.on_clip() 98 self.on_clip(lr_scheduler.get_last_lr()[0])
99 99
100 self.optimizer.step() 100 self.optimizer.step()
101 lr_scheduler.step() 101 lr_scheduler.step()