path: root/training/lr.py
author      Volpeon <git@volpeon.ink>    2023-01-10 09:22:02 +0100
committer   Volpeon <git@volpeon.ink>    2023-01-10 09:22:02 +0100
commit      33e7d2ed37e32657ca94d92815043026c4cea7c0 (patch)
tree        0af4d6ad0ba92a168e3ec17675147c76afe1baf0 /training/lr.py
parent      Enable buckets for validation, fixed validation repeat arg (diff)
download    textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.gz
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.bz2
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.zip
Added arg to disable tag shuffling
Diffstat (limited to 'training/lr.py')
-rw-r--r--    training/lr.py    20
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 68e0f72..dfb1743 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -48,7 +48,7 @@ class LRFinder():
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
-        num_train_batches: int = 1,
+        num_train_batches: int = math.inf,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
     ):
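The default for num_train_batches changes from 1 to math.inf, so the LR finder presumably walks the entire training loader each epoch unless an explicit cap is passed. A minimal sketch of how such a cap is typically enforced; run_epoch and its loop body are hypothetical and not part of this diff:

    import math

    def run_epoch(batches, num_train_batches=math.inf):
        # Iterate until the loader is exhausted or the cap is reached;
        # math.inf (the new default) means the cap never triggers.
        processed = 0
        for step, batch in enumerate(batches):
            if step >= num_train_batches:
                break
            processed += 1  # a real training step would go here
        return processed

    assert run_epoch(range(8)) == 8                        # new default: all batches
    assert run_epoch(range(8), num_train_batches=1) == 1   # old default: one batch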
@@ -156,6 +156,15 @@ class LRFinder():
         # self.model.load_state_dict(self.model_state)
         # self.optimizer.load_state_dict(self.optimizer_state)
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()
 
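This hunk moves the skip_start/skip_end trimming ahead of the plotting code, so the plotted curves and the LR suggestion below both see the trimmed arrays. The skip_end == 0 branch is presumably needed because a slice ending at -0 is the same as ending at 0 and would empty the list, as this small standalone check shows:

    lrs = list(range(20))
    skip_start, skip_end = 10, 0

    # lrs[skip_start:-skip_end] is lrs[10:0] when skip_end == 0, i.e. empty:
    assert lrs[skip_start:-skip_end] == []

    # The branching applied in the diff keeps the tail intact instead:
    trimmed = lrs[skip_start:] if skip_end == 0 else lrs[skip_start:-skip_end]
    assert trimmed == list(range(10, 20))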
@@ -171,15 +180,6 @@ class LRFinder():
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         try:
             min_grad_idx = np.gradient(np.array(losses)).argmin()
         except ValueError:
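The "steepest gradient" suggestion picks the learning rate where the recorded loss falls fastest, and the try/except guards against arrays too short for np.gradient (it needs at least two points). A self-contained sketch of that selection; the loss values here are made up for illustration:

    import numpy as np

    lrs = np.logspace(-5, -1, num=7)
    losses = np.array([1.0, 0.9, 0.7, 0.4, 0.35, 0.5, 1.2])

    # Most negative finite difference of the loss curve = steepest descent.
    min_grad_idx = np.gradient(losses).argmin()
    print(f"LR suggestion: {lrs[min_grad_idx]:.2e}")  # picks lrs[2] here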