-rw-r--r--  train_dreambooth.py |  4
-rw-r--r--  train_ti.py         |  6
-rw-r--r--  training/lr.py      | 38
3 files changed, 26 insertions(+), 22 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1e49474..218018b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -634,7 +634,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -901,7 +901,7 @@ def main():
             on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
             on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
         )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
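Taken together, the two hunks above invert the sweep: with --find_lr the optimizer now starts at the small rate (1e-4) and the finder climbs to end_lr=1e2, rather than starting from a base rate of 1e2 with a 1e-4 floor. As a rough, illustrative sketch (not part of the commit) of the range such an exponential sweep covers, using numpy and the num_epochs=100 default from training/lr.py; the formula mirrors the factor used by the new get_exponential_schedule shown in the training/lr.py diff below:

# Illustrative only: the learning rates an exponential sweep visits between
# the new start and end values set by the training scripts.
import numpy as np

start_lr = 1e-4    # args.learning_rate when --find_lr is set
end_lr = 1e2       # passed to lr_finder.run(end_lr=...)
num_epochs = 100   # default of LRFinder.run()

lrs = start_lr * (end_lr / start_lr) ** (np.arange(num_epochs + 1) / num_epochs)
print(lrs[0], lrs[-1])  # ~1e-4 ... ~1e2, evenly spaced on a log scale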
diff --git a/train_ti.py b/train_ti.py
index 2b3f017..102c0fa 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -584,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -853,9 +853,9 @@ def main():
             on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
             on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
         )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
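train_ti.py receives the same inversion and additionally saves lr.png at dpi=300. In both scripts the run ends with the finder printing "LR suggestion: steepest gradient", i.e. the learning rate at which the loss was falling fastest, computed in training/lr.py (diffed below) as the most negative entry of np.gradient over the recorded losses. A toy illustration of that selection rule with made-up loss values, assuming only numpy:

# Toy data only: seven log-spaced LRs and a hand-written loss curve that
# drops, bottoms out, then diverges -- the shape an LR sweep typically shows.
import numpy as np

lrs = np.geomspace(1e-4, 1e2, num=7)
losses = np.array([1.00, 0.98, 0.90, 0.60, 0.20, 0.40, 3.00])

min_grad_idx = np.gradient(losses).argmin()  # steepest downward slope
print("LR suggestion:", lrs[min_grad_idx])   # 0.1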
diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()
 
         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_xlabel(f"Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()
 
 
-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
 
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
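The reworked get_exponential_schedule multiplies each param group's base LR by (end_lr / base_lr) ** (current_epoch / num_epochs), with one lambda per group built via functools.partial, so the sweep runs geometrically from the configured starting LR up to end_lr instead of the old additive power-10 ramp toward min_lr. A standalone sketch (not part of the commit) that reproduces the new helper and checks the end points of the sweep; it assumes PyTorch is installed, and the SGD optimizer with a dummy parameter is only a placeholder:

# Sketch: verify the schedule climbs from the optimizer's base LR to end_lr.
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr, num_epochs, last_epoch=-1):
    # LambdaLR multiplies each group's base LR by the returned factor, so the
    # factor is expressed relative to base_lr: 1.0 at epoch 0, end_lr/base_lr
    # at epoch num_epochs.
    def lr_lambda(base_lr, current_epoch):
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambdas, last_epoch)


param = torch.nn.Parameter(torch.zeros(1))              # placeholder parameter
optimizer = torch.optim.SGD([param], lr=1e-4)           # start of the sweep
scheduler = get_exponential_schedule(optimizer, end_lr=1e2, num_epochs=100)

for _ in range(100):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())                          # ~[100.0], i.e. end_lr

With the old convention the base LR was the large value and min_lr only scaled the starting factor; under the new one the training scripts set the base LR to 1e-4 and the finder ramps it up to end_lr=1e2, which is why both call sites were changed in the same commit.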