diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 51 | ||||
-rw-r--r-- | training/util.py | 2 |
2 files changed, 26 insertions, 27 deletions
diff --git a/training/lr.py b/training/lr.py index c765150..68e0f72 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -1,5 +1,5 @@ | |||
1 | import math | 1 | import math |
2 | import copy | 2 | from contextlib import _GeneratorContextManager, nullcontext |
3 | from typing import Callable, Any, Tuple, Union | 3 | from typing import Callable, Any, Tuple, Union |
4 | from functools import partial | 4 | from functools import partial |
5 | 5 | ||
@@ -25,9 +25,9 @@ class LRFinder(): | |||
25 | train_dataloader, | 25 | train_dataloader, |
26 | val_dataloader, | 26 | val_dataloader, |
27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, |
29 | on_clip: Callable[[], None] = noop, | 29 | on_clip: Callable[[], None] = noop, |
30 | on_eval: Callable[[], None] = noop | 30 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext |
31 | ): | 31 | ): |
32 | self.accelerator = accelerator | 32 | self.accelerator = accelerator |
33 | self.model = model | 33 | self.model = model |
@@ -51,7 +51,6 @@ class LRFinder(): | |||
51 | num_train_batches: int = 1, | 51 | num_train_batches: int = 1, |
52 | num_val_batches: int = math.inf, | 52 | num_val_batches: int = math.inf, |
53 | smooth_f: float = 0.05, | 53 | smooth_f: float = 0.05, |
54 | diverge_th: int = 5, | ||
55 | ): | 54 | ): |
56 | best_loss = None | 55 | best_loss = None |
57 | best_acc = None | 56 | best_acc = None |
@@ -84,40 +83,40 @@ class LRFinder(): | |||
84 | avg_acc = AverageMeter() | 83 | avg_acc = AverageMeter() |
85 | 84 | ||
86 | self.model.train() | 85 | self.model.train() |
87 | self.on_train() | ||
88 | 86 | ||
89 | for step, batch in enumerate(self.train_dataloader): | 87 | with self.on_train(): |
90 | if step >= num_train_batches: | 88 | for step, batch in enumerate(self.train_dataloader): |
91 | break | 89 | if step >= num_train_batches: |
90 | break | ||
92 | 91 | ||
93 | with self.accelerator.accumulate(self.model): | 92 | with self.accelerator.accumulate(self.model): |
94 | loss, acc, bsz = self.loss_fn(step, batch) | 93 | loss, acc, bsz = self.loss_fn(step, batch) |
95 | 94 | ||
96 | self.accelerator.backward(loss) | 95 | self.accelerator.backward(loss) |
97 | 96 | ||
98 | if self.accelerator.sync_gradients: | 97 | if self.accelerator.sync_gradients: |
99 | self.on_clip() | 98 | self.on_clip() |
100 | 99 | ||
101 | self.optimizer.step() | 100 | self.optimizer.step() |
102 | lr_scheduler.step() | 101 | lr_scheduler.step() |
103 | self.optimizer.zero_grad(set_to_none=True) | 102 | self.optimizer.zero_grad(set_to_none=True) |
104 | 103 | ||
105 | if self.accelerator.sync_gradients: | 104 | if self.accelerator.sync_gradients: |
106 | progress_bar.update(1) | 105 | progress_bar.update(1) |
107 | 106 | ||
108 | self.model.eval() | 107 | self.model.eval() |
109 | self.on_eval() | ||
110 | 108 | ||
111 | with torch.inference_mode(): | 109 | with torch.inference_mode(): |
112 | for step, batch in enumerate(self.val_dataloader): | 110 | with self.on_eval(): |
113 | if step >= num_val_batches: | 111 | for step, batch in enumerate(self.val_dataloader): |
114 | break | 112 | if step >= num_val_batches: |
113 | break | ||
115 | 114 | ||
116 | loss, acc, bsz = self.loss_fn(step, batch, True) | 115 | loss, acc, bsz = self.loss_fn(step, batch, True) |
117 | avg_loss.update(loss.detach_(), bsz) | 116 | avg_loss.update(loss.detach_(), bsz) |
118 | avg_acc.update(acc.detach_(), bsz) | 117 | avg_acc.update(acc.detach_(), bsz) |
119 | 118 | ||
120 | progress_bar.update(1) | 119 | progress_bar.update(1) |
121 | 120 | ||
122 | loss = avg_loss.avg.item() | 121 | loss = avg_loss.avg.item() |
123 | acc = avg_acc.avg.item() | 122 | acc = avg_acc.avg.item() |
diff --git a/training/util.py b/training/util.py index 6f1e85a..bed7111 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -262,7 +262,7 @@ class EMAModel: | |||
262 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
263 | 263 | ||
264 | @contextmanager | 264 | @contextmanager |
265 | def apply_temporary(self, parameters): | 265 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
266 | try: | 266 | try: |
267 | parameters = list(parameters) | 267 | parameters = list(parameters) |
268 | original_params = [p.clone() for p in parameters] | 268 | original_params = [p.clone() for p in parameters] |