Diffstat (limited to 'training')
-rw-r--r--  training/lr.py  7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/training/lr.py b/training/lr.py
index c8dc040..3cdf994 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -12,7 +12,7 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
-def noop():
+def noop(*args, **kwargs):
     pass
 
 
@@ -26,6 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
+        on_clip: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
         self.accelerator = accelerator
@@ -35,6 +36,7 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
+        self.on_clip = on_clip
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -93,6 +95,9 @@ class LRFinder():
 
             self.accelerator.backward(loss)
 
+            if self.accelerator.sync_gradients:
+                self.on_clip()
+
             self.optimizer.step()
             lr_scheduler.step()
             self.optimizer.zero_grad(set_to_none=True)
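
The new on_clip hook runs only when accelerator.sync_gradients is true, i.e. on the micro-batch where Accelerate has just accumulated and synchronized gradients, which is the only point where clipping them is meaningful. A minimal sketch of how a caller might wire it up (the model, max_grad_norm, and Accelerator setup below are illustrative assumptions, not part of this commit):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)          # stand-in for the real model
max_grad_norm = 1.0                    # assumed clipping threshold

# Clip the global gradient norm right before optimizer.step().
# Accelerator.clip_grad_norm_ unscales mixed-precision gradients first,
# so it should only be called once gradients have been synchronized,
# exactly the condition the new sync_gradients guard enforces.
def clip_callback():
    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)

Passed as on_clip=clip_callback to LRFinder, the callback then runs once per effective optimizer step rather than once per micro-batch; the widened default noop(*args, **kwargs) keeps the hook optional and tolerant of any arguments a future caller might pass.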
