summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-06 11:14:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-06 11:14:24 +0100
commit672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch)
tree1afb3a943af3fa7c935d133cf2768a33f11f8235 /training/lr.py
parentPackage update (diff)
downloadtextual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.gz
textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.bz2
textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.zip
Use context manager for EMA, on_train/eval hooks
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py51
1 files changed, 25 insertions, 26 deletions
diff --git a/training/lr.py b/training/lr.py
index c765150..68e0f72 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,5 @@
1import math 1import math
2import copy 2from contextlib import _GeneratorContextManager, nullcontext
3from typing import Callable, Any, Tuple, Union 3from typing import Callable, Any, Tuple, Union
4from functools import partial 4from functools import partial
5 5
@@ -25,9 +25,9 @@ class LRFinder():
25 train_dataloader, 25 train_dataloader,
26 val_dataloader, 26 val_dataloader,
27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], None] = noop, 28 on_train: Callable[[], _GeneratorContextManager] = nullcontext,
29 on_clip: Callable[[], None] = noop, 29 on_clip: Callable[[], None] = noop,
30 on_eval: Callable[[], None] = noop 30 on_eval: Callable[[], _GeneratorContextManager] = nullcontext
31 ): 31 ):
32 self.accelerator = accelerator 32 self.accelerator = accelerator
33 self.model = model 33 self.model = model
@@ -51,7 +51,6 @@ class LRFinder():
51 num_train_batches: int = 1, 51 num_train_batches: int = 1,
52 num_val_batches: int = math.inf, 52 num_val_batches: int = math.inf,
53 smooth_f: float = 0.05, 53 smooth_f: float = 0.05,
54 diverge_th: int = 5,
55 ): 54 ):
56 best_loss = None 55 best_loss = None
57 best_acc = None 56 best_acc = None
@@ -84,40 +83,40 @@ class LRFinder():
84 avg_acc = AverageMeter() 83 avg_acc = AverageMeter()
85 84
86 self.model.train() 85 self.model.train()
87 self.on_train()
88 86
89 for step, batch in enumerate(self.train_dataloader): 87 with self.on_train():
90 if step >= num_train_batches: 88 for step, batch in enumerate(self.train_dataloader):
91 break 89 if step >= num_train_batches:
90 break
92 91
93 with self.accelerator.accumulate(self.model): 92 with self.accelerator.accumulate(self.model):
94 loss, acc, bsz = self.loss_fn(step, batch) 93 loss, acc, bsz = self.loss_fn(step, batch)
95 94
96 self.accelerator.backward(loss) 95 self.accelerator.backward(loss)
97 96
98 if self.accelerator.sync_gradients: 97 if self.accelerator.sync_gradients:
99 self.on_clip() 98 self.on_clip()
100 99
101 self.optimizer.step() 100 self.optimizer.step()
102 lr_scheduler.step() 101 lr_scheduler.step()
103 self.optimizer.zero_grad(set_to_none=True) 102 self.optimizer.zero_grad(set_to_none=True)
104 103
105 if self.accelerator.sync_gradients: 104 if self.accelerator.sync_gradients:
106 progress_bar.update(1) 105 progress_bar.update(1)
107 106
108 self.model.eval() 107 self.model.eval()
109 self.on_eval()
110 108
111 with torch.inference_mode(): 109 with torch.inference_mode():
112 for step, batch in enumerate(self.val_dataloader): 110 with self.on_eval():
113 if step >= num_val_batches: 111 for step, batch in enumerate(self.val_dataloader):
114 break 112 if step >= num_val_batches:
113 break
115 114
116 loss, acc, bsz = self.loss_fn(step, batch, True) 115 loss, acc, bsz = self.loss_fn(step, batch, True)
117 avg_loss.update(loss.detach_(), bsz) 116 avg_loss.update(loss.detach_(), bsz)
118 avg_acc.update(acc.detach_(), bsz) 117 avg_acc.update(acc.detach_(), bsz)
119 118
120 progress_bar.update(1) 119 progress_bar.update(1)
121 120
122 loss = avg_loss.avg.item() 121 loss = avg_loss.avg.item()
123 acc = avg_acc.avg.item() 122 acc = avg_acc.avg.item()