Diffstat (limited to 'training/lr.py')
-rw-r--r--   training/lr.py   29
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
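
For context: this change replaces LRFinder's separate model argument and on_* hook parameters with a single TrainingCallbacks bundle from training.functional, presumably so the LR finder and the main training loop share one callback object; the model is now obtained via callbacks.on_model() and on_prepare() is invoked before the sweep starts. The sketch below shows how a caller might construct LRFinder after this change. It is a hypothetical usage sketch, not code from the repository: the TrainingCallbacks field names are assumed from the attributes the diff accesses, and model, accelerator, optimizer, the dataloaders and loss_fn stand in for the caller's existing training setup.

# Hypothetical usage sketch of the new interface. The TrainingCallbacks
# fields used here (on_model, on_train, on_eval, on_before_optimize,
# on_after_optimize) are assumed from the attributes the diff accesses;
# the real class lives in training.functional and may differ.
from contextlib import nullcontext

from training.functional import TrainingCallbacks
from training.lr import LRFinder

callbacks = TrainingCallbacks(
    on_model=lambda: model,                 # LRFinder now reads the model from the callbacks
    on_train=lambda epoch: nullcontext(),   # context manager entered around each training epoch
    on_eval=lambda: nullcontext(),          # context manager entered around validation
    on_before_optimize=lambda epoch: None,  # invoked just before optimizer.step()
    on_after_optimize=lambda lr: None,      # invoked with the current LR after a synced step
)

lr_finder = LRFinder(
    accelerator=accelerator,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    callbacks=callbacks,                    # replaces the old on_* keyword arguments
)
# lr_finder.run(...) proceeds as before; run() now also calls callbacks.on_prepare() up front.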