Diffstat (limited to 'training')
-rw-r--r--  training/lr.py    51
-rw-r--r--  training/util.py   2
2 files changed, 26 insertions, 27 deletions
diff --git a/training/lr.py b/training/lr.py
index c765150..68e0f72 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,5 @@
 import math
-import copy
+from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union
 from functools import partial
 
@@ -25,9 +25,9 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], None] = noop,
+        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
         on_clip: Callable[[], None] = noop,
-        on_eval: Callable[[], None] = noop
+        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
         self.model = model
@@ -51,7 +51,6 @@ class LRFinder():
         num_train_batches: int = 1,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
-        diverge_th: int = 5,
     ):
         best_loss = None
         best_acc = None
@@ -84,40 +83,40 @@ class LRFinder():
             avg_acc = AverageMeter()
 
             self.model.train()
-            self.on_train()
 
-            for step, batch in enumerate(self.train_dataloader):
-                if step >= num_train_batches:
-                    break
+            with self.on_train():
+                for step, batch in enumerate(self.train_dataloader):
+                    if step >= num_train_batches:
+                        break
 
-                with self.accelerator.accumulate(self.model):
-                    loss, acc, bsz = self.loss_fn(step, batch)
+                    with self.accelerator.accumulate(self.model):
+                        loss, acc, bsz = self.loss_fn(step, batch)
 
-                    self.accelerator.backward(loss)
+                        self.accelerator.backward(loss)
 
-                    if self.accelerator.sync_gradients:
-                        self.on_clip()
+                        if self.accelerator.sync_gradients:
+                            self.on_clip()
 
-                    self.optimizer.step()
-                    lr_scheduler.step()
-                    self.optimizer.zero_grad(set_to_none=True)
+                        self.optimizer.step()
+                        lr_scheduler.step()
+                        self.optimizer.zero_grad(set_to_none=True)
 
-                if self.accelerator.sync_gradients:
-                    progress_bar.update(1)
+                    if self.accelerator.sync_gradients:
+                        progress_bar.update(1)
 
             self.model.eval()
-            self.on_eval()
 
             with torch.inference_mode():
-                for step, batch in enumerate(self.val_dataloader):
-                    if step >= num_val_batches:
-                        break
+                with self.on_eval():
+                    for step, batch in enumerate(self.val_dataloader):
+                        if step >= num_val_batches:
+                            break
 
-                    loss, acc, bsz = self.loss_fn(step, batch, True)
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
+                        loss, acc, bsz = self.loss_fn(step, batch, True)
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
 
-                    progress_bar.update(1)
+                        progress_bar.update(1)
 
             loss = avg_loss.avg.item()
             acc = avg_acc.avg.item()
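
The lr.py change above turns on_train and on_eval from plain callbacks (Callable[[], None], defaulting to noop) into factories that return a context manager (Callable[[], _GeneratorContextManager], defaulting to nullcontext), and LRFinder.run() now enters them around the training and validation loops. A minimal sketch of a caller-side callback under the new contract; the logging is purely illustrative and not part of this commit:

    from contextlib import contextmanager
    from functools import partial

    @contextmanager
    def log_phase(name: str):
        # setup runs before the wrapped loop, teardown runs after it,
        # even if the loop body raises
        print(f"entering {name}")
        try:
            yield
        finally:
            print(f"leaving {name}")

    # LRFinder(..., on_train=partial(log_phase, "train"),
    #               on_eval=partial(log_phase, "eval"))
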
diff --git a/training/util.py b/training/util.py
index 6f1e85a..bed7111 100644
--- a/training/util.py
+++ b/training/util.py
@@ -262,7 +262,7 @@ class EMAModel:
             raise ValueError("collected_params and shadow_params must have the same length")
 
     @contextmanager
-    def apply_temporary(self, parameters):
+    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
         try:
             parameters = list(parameters)
             original_params = [p.clone() for p in parameters]
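
The type hint added to EMAModel.apply_temporary fits the same pattern: it is already a @contextmanager, so once bound to a parameter list it matches the new Callable[[], _GeneratorContextManager] signature of on_eval. A sketch of that wiring, assuming an EMAModel instance named ema tracking model.parameters() (both names are illustrative, not taken from this commit):

    from functools import partial

    # validation batches in LRFinder.run() would then execute with the EMA
    # weights swapped in, and the original weights restored when the
    # with-block around the validation loop exits
    on_eval = partial(ema.apply_temporary, model.parameters())
    # LRFinder(..., on_eval=on_eval)
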