 train_dreambooth.py |  4 ++--
 train_ti.py         |  6 +++---
 training/lr.py      | 38 +++++++++++++++++++++-----------------
 3 files changed, 26 insertions(+), 22 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1e49474..218018b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -634,7 +634,7 @@ def main():
     )

     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -901,7 +901,7 @@ def main():
             on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
             on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
         )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)

         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
diff --git a/train_ti.py b/train_ti.py
index 2b3f017..102c0fa 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -584,7 +584,7 @@ def main():
     )

     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -853,9 +853,9 @@ def main():
             on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
             on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
         )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)

-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()

         quit()
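Previously both training scripts pinned args.learning_rate to 1e2 and passed min_lr=1e-4 to the finder; with this change they keep the base learning rate at 1e-4 and let the finder ramp it up to end_lr=1e2 instead. A rough sketch of the learning rates such a sweep visits under the new geometric schedule from training/lr.py below (num_epochs=100 is the finder's default; the sampled epochs are illustrative, not repo code):

# Illustration only: lr values visited by the new sweep, assuming one
# scheduler step per epoch and the defaults shown above.
base_lr, end_lr, num_epochs = 1e-4, 1e2, 100

for epoch in (0, 25, 50, 75, 100):
    lr = base_lr * (end_lr / base_lr) ** (epoch / num_epochs)
    print(f"epoch {epoch:3d}: lr = {lr:.1e}")
# epoch   0: lr = 1.0e-04
# epoch  25: lr = 3.2e-03
# epoch  50: lr = 1.0e-01
# epoch  75: lr = 3.2e+00
# epoch 100: lr = 1.0e+02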
diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial

 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []

-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)

         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break

-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()

         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_xlabel(f"Learning rate")
         ax_loss.set_ylabel("Loss")

-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")

         print("LR suggestion: steepest gradient")
         min_grad_idx = None

+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()


-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
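For reference, a self-contained sketch of how the rewritten schedule behaves; the throwaway SGD optimizer and the driver loop are assumptions for illustration, not part of the repo. LambdaLR multiplies each param group's initial lr by the lambda's return value, so the effective rate is base_lr * (end_lr / base_lr) ** (epoch / num_epochs): a geometric ramp from the configured learning rate (1e-4 in the scripts above) up to end_lr (1e2).

# Standalone illustration (assumed setup, not repo code): exercising the new
# schedule with a dummy SGD optimizer to show the geometric lr ramp.
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
    def lr_lambda(base_lr: float, current_epoch: int):
        # LambdaLR multiplies the group's initial lr by this factor, so the
        # effective lr is base_lr * (end_lr / base_lr) ** (epoch / num_epochs).
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    # One lambda per param group, each closed over that group's starting lr.
    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

    return LambdaLR(optimizer, lr_lambdas, last_epoch)


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-4)  # matches args.learning_rate = 1e-4
scheduler = get_exponential_schedule(optimizer, end_lr=1e2, num_epochs=100)

for epoch in range(101):
    if epoch in (0, 50, 100):
        print(f"epoch {epoch:3d}: lr = {optimizer.param_groups[0]['lr']:.1e}")
    optimizer.step()
    scheduler.step()
# epoch   0: lr = 1.0e-04
# epoch  50: lr = 1.0e-01
# epoch 100: lr = 1.0e+02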