diff options
-rw-r--r-- | train_dreambooth.py | 5 | ||||
-rw-r--r-- | train_ti.py | 10 | ||||
-rw-r--r-- | training/lr.py | 38 |
3 files changed, 30 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a62cec9..325fe90 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -970,9 +970,8 @@ def main(): | |||
970 | avg_loss_val.update(loss.detach_(), bsz) | 970 | avg_loss_val.update(loss.detach_(), bsz) |
971 | avg_acc_val.update(acc.detach_(), bsz) | 971 | avg_acc_val.update(acc.detach_(), bsz) |
972 | 972 | ||
973 | if accelerator.sync_gradients: | 973 | local_progress_bar.update(1) |
974 | local_progress_bar.update(1) | 974 | global_progress_bar.update(1) |
975 | global_progress_bar.update(1) | ||
976 | 975 | ||
977 | logs = { | 976 | logs = { |
978 | "val/loss": avg_loss_val.avg.item(), | 977 | "val/loss": avg_loss_val.avg.item(), |
diff --git a/train_ti.py b/train_ti.py index 32f44f4..870b2ba 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -548,9 +548,6 @@ def main(): | |||
548 | args.train_batch_size * accelerator.num_processes | 548 | args.train_batch_size * accelerator.num_processes |
549 | ) | 549 | ) |
550 | 550 | ||
551 | if args.find_lr: | ||
552 | args.learning_rate = 1e2 | ||
553 | |||
554 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 551 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
555 | if args.use_8bit_adam: | 552 | if args.use_8bit_adam: |
556 | try: | 553 | try: |
@@ -783,7 +780,7 @@ def main(): | |||
783 | 780 | ||
784 | if args.find_lr: | 781 | if args.find_lr: |
785 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 782 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
786 | lr_finder.run(num_train_steps=2) | 783 | lr_finder.run(min_lr=1e-6, num_train_batches=4) |
787 | 784 | ||
788 | plt.savefig(basepath.joinpath("lr.png")) | 785 | plt.savefig(basepath.joinpath("lr.png")) |
789 | plt.close() | 786 | plt.close() |
@@ -908,9 +905,8 @@ def main(): | |||
908 | avg_loss_val.update(loss.detach_(), bsz) | 905 | avg_loss_val.update(loss.detach_(), bsz) |
909 | avg_acc_val.update(acc.detach_(), bsz) | 906 | avg_acc_val.update(acc.detach_(), bsz) |
910 | 907 | ||
911 | if accelerator.sync_gradients: | 908 | local_progress_bar.update(1) |
912 | local_progress_bar.update(1) | 909 | global_progress_bar.update(1) |
913 | global_progress_bar.update(1) | ||
914 | 910 | ||
915 | logs = { | 911 | logs = { |
916 | "val/loss": avg_loss_val.avg.item(), | 912 | "val/loss": avg_loss_val.avg.item(), |
diff --git a/training/lr.py b/training/lr.py index 5343f24..8e558e1 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -1,3 +1,6 @@ | |||
1 | import math | ||
2 | import copy | ||
3 | |||
1 | import matplotlib.pyplot as plt | 4 | import matplotlib.pyplot as plt |
2 | import numpy as np | 5 | import numpy as np |
3 | import torch | 6 | import torch |
@@ -16,15 +19,22 @@ class LRFinder(): | |||
16 | self.val_dataloader = val_dataloader | 19 | self.val_dataloader = val_dataloader |
17 | self.loss_fn = loss_fn | 20 | self.loss_fn = loss_fn |
18 | 21 | ||
19 | def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): | 22 | self.model_state = copy.deepcopy(model.state_dict()) |
23 | self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | ||
24 | |||
25 | def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | ||
20 | best_loss = None | 26 | best_loss = None |
21 | lrs = [] | 27 | lrs = [] |
22 | losses = [] | 28 | losses = [] |
23 | 29 | ||
24 | lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) | 30 | lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs) |
31 | |||
32 | steps = min(num_train_batches, len(self.train_dataloader)) | ||
33 | steps += min(num_val_batches, len(self.val_dataloader)) | ||
34 | steps *= num_epochs | ||
25 | 35 | ||
26 | progress_bar = tqdm( | 36 | progress_bar = tqdm( |
27 | range(num_epochs * (num_train_steps + num_val_steps)), | 37 | range(steps), |
28 | disable=not self.accelerator.is_local_main_process, | 38 | disable=not self.accelerator.is_local_main_process, |
29 | dynamic_ncols=True | 39 | dynamic_ncols=True |
30 | ) | 40 | ) |
@@ -38,6 +48,9 @@ class LRFinder(): | |||
38 | self.model.train() | 48 | self.model.train() |
39 | 49 | ||
40 | for step, batch in enumerate(self.train_dataloader): | 50 | for step, batch in enumerate(self.train_dataloader): |
51 | if step >= num_train_batches: | ||
52 | break | ||
53 | |||
41 | with self.accelerator.accumulate(self.model): | 54 | with self.accelerator.accumulate(self.model): |
42 | loss, acc, bsz = self.loss_fn(batch) | 55 | loss, acc, bsz = self.loss_fn(batch) |
43 | 56 | ||
@@ -49,21 +62,17 @@ class LRFinder(): | |||
49 | if self.accelerator.sync_gradients: | 62 | if self.accelerator.sync_gradients: |
50 | progress_bar.update(1) | 63 | progress_bar.update(1) |
51 | 64 | ||
52 | if step >= num_train_steps: | ||
53 | break | ||
54 | |||
55 | self.model.eval() | 65 | self.model.eval() |
56 | 66 | ||
57 | with torch.inference_mode(): | 67 | with torch.inference_mode(): |
58 | for step, batch in enumerate(self.val_dataloader): | 68 | for step, batch in enumerate(self.val_dataloader): |
69 | if step >= num_val_batches: | ||
70 | break | ||
71 | |||
59 | loss, acc, bsz = self.loss_fn(batch) | 72 | loss, acc, bsz = self.loss_fn(batch) |
60 | avg_loss.update(loss.detach_(), bsz) | 73 | avg_loss.update(loss.detach_(), bsz) |
61 | 74 | ||
62 | if self.accelerator.sync_gradients: | 75 | progress_bar.update(1) |
63 | progress_bar.update(1) | ||
64 | |||
65 | if step >= num_val_steps: | ||
66 | break | ||
67 | 76 | ||
68 | lr_scheduler.step() | 77 | lr_scheduler.step() |
69 | 78 | ||
@@ -87,6 +96,9 @@ class LRFinder(): | |||
87 | "lr": lr, | 96 | "lr": lr, |
88 | }) | 97 | }) |
89 | 98 | ||
99 | self.model.load_state_dict(self.model_state) | ||
100 | self.optimizer.load_state_dict(self.optimizer_state) | ||
101 | |||
90 | if loss > diverge_th * best_loss: | 102 | if loss > diverge_th * best_loss: |
91 | print("Stopping early, the loss has diverged") | 103 | print("Stopping early, the loss has diverged") |
92 | break | 104 | break |
@@ -120,8 +132,8 @@ class LRFinder(): | |||
120 | ax.set_ylabel("Loss") | 132 | ax.set_ylabel("Loss") |
121 | 133 | ||
122 | 134 | ||
123 | def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): | 135 | def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1): |
124 | def lr_lambda(current_epoch: int): | 136 | def lr_lambda(current_epoch: int): |
125 | return (current_epoch / num_epochs) ** 5 | 137 | return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr) |
126 | 138 | ||
127 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 139 | return LambdaLR(optimizer, lr_lambda, last_epoch) |