Diffstat (limited to 'training/lr.py')
 training/lr.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/training/lr.py b/training/lr.py
index c8dc040..3cdf994 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -12,7 +12,7 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
-def noop():
+def noop(*args, **kwargs):
     pass
 
 
@@ -26,6 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
+        on_clip: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
         self.accelerator = accelerator
31 self.accelerator = accelerator 32 self.accelerator = accelerator
@@ -35,6 +36,7 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
+        self.on_clip = on_clip
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -93,6 +95,9 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
+                if self.accelerator.sync_gradients:
+                    self.on_clip()
+
                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
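
Note on the change: the sync_gradients guard means on_clip fires only on steps where Accelerate actually synchronizes gradients (the end of a gradient-accumulation cycle), which is the correct point to clip; noop now accepts *args and **kwargs, presumably so it can stand in for callbacks invoked with arguments. Below is a minimal caller-side sketch of how the new hook might be wired up, assuming a Hugging Face Accelerate setup. Constructor arguments not visible in this diff (model, optimizer, train_dataloader) are assumptions, not confirmed by the source.

# Hypothetical wiring for the new on_clip hook. Only accelerator,
# val_dataloader, loss_fn, and the on_train/on_clip/on_eval callbacks
# are confirmed by the diff above; the rest is illustrative.
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

# model, optimizer, and the dataloaders are assumed to be defined and
# passed through accelerator.prepare(...) elsewhere in the script.
finder = LRFinder(
    accelerator=accelerator,
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    # Invoked only when gradients are synchronized, so the clip sees the
    # fully accumulated gradient rather than a partial one.
    on_clip=lambda: accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0),
)

Using accelerator.clip_grad_norm_ (a real Accelerate API) rather than torch.nn.utils.clip_grad_norm_ matters here because it unscales gradients first under mixed precision.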