summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/lr.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
26 val_dataloader, 26 val_dataloader,
27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], _GeneratorContextManager] = nullcontext, 28 on_train: Callable[[], _GeneratorContextManager] = nullcontext,
29 on_clip: Callable[[float], None] = noop, 29 on_before_optimize: Callable[[], None] = noop,
30 on_after_optimize: Callable[[float], None] = noop,
30 on_eval: Callable[[], _GeneratorContextManager] = nullcontext 31 on_eval: Callable[[], _GeneratorContextManager] = nullcontext
31 ): 32 ):
32 self.accelerator = accelerator 33 self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
36 self.val_dataloader = val_dataloader 37 self.val_dataloader = val_dataloader
37 self.loss_fn = loss_fn 38 self.loss_fn = loss_fn
38 self.on_train = on_train 39 self.on_train = on_train
39 self.on_clip = on_clip 40 self.on_before_optimize = on_before_optimize
41 self.on_after_optimize = on_after_optimize
40 self.on_eval = on_eval 42 self.on_eval = on_eval
41 43
42 # self.model_state = copy.deepcopy(model.state_dict()) 44 # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():
94 96
95 self.accelerator.backward(loss) 97 self.accelerator.backward(loss)
96 98
97 if self.accelerator.sync_gradients: 99 self.on_before_optimize()
98 self.on_clip(lr_scheduler.get_last_lr()[0])
99 100
100 self.optimizer.step() 101 self.optimizer.step()
101 lr_scheduler.step() 102 lr_scheduler.step()
102 self.optimizer.zero_grad(set_to_none=True) 103 self.optimizer.zero_grad(set_to_none=True)
103 104
104 if self.accelerator.sync_gradients: 105 if self.accelerator.sync_gradients:
106 self.on_after_optimize(lr_scheduler.get_last_lr()[0])
107
105 progress_bar.update(1) 108 progress_bar.update(1)
106 109
107 self.model.eval() 110 self.model.eval()