diff options
| -rw-r--r-- | train_dreambooth.py | 5 | ||||
| -rw-r--r-- | train_ti.py | 10 | ||||
| -rw-r--r-- | training/lr.py | 38 |
3 files changed, 30 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a62cec9..325fe90 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -970,9 +970,8 @@ def main(): | |||
| 970 | avg_loss_val.update(loss.detach_(), bsz) | 970 | avg_loss_val.update(loss.detach_(), bsz) |
| 971 | avg_acc_val.update(acc.detach_(), bsz) | 971 | avg_acc_val.update(acc.detach_(), bsz) |
| 972 | 972 | ||
| 973 | if accelerator.sync_gradients: | 973 | local_progress_bar.update(1) |
| 974 | local_progress_bar.update(1) | 974 | global_progress_bar.update(1) |
| 975 | global_progress_bar.update(1) | ||
| 976 | 975 | ||
| 977 | logs = { | 976 | logs = { |
| 978 | "val/loss": avg_loss_val.avg.item(), | 977 | "val/loss": avg_loss_val.avg.item(), |
diff --git a/train_ti.py b/train_ti.py index 32f44f4..870b2ba 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -548,9 +548,6 @@ def main(): | |||
| 548 | args.train_batch_size * accelerator.num_processes | 548 | args.train_batch_size * accelerator.num_processes |
| 549 | ) | 549 | ) |
| 550 | 550 | ||
| 551 | if args.find_lr: | ||
| 552 | args.learning_rate = 1e2 | ||
| 553 | |||
| 554 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 551 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
| 555 | if args.use_8bit_adam: | 552 | if args.use_8bit_adam: |
| 556 | try: | 553 | try: |
| @@ -783,7 +780,7 @@ def main(): | |||
| 783 | 780 | ||
| 784 | if args.find_lr: | 781 | if args.find_lr: |
| 785 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 782 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
| 786 | lr_finder.run(num_train_steps=2) | 783 | lr_finder.run(min_lr=1e-6, num_train_batches=4) |
| 787 | 784 | ||
| 788 | plt.savefig(basepath.joinpath("lr.png")) | 785 | plt.savefig(basepath.joinpath("lr.png")) |
| 789 | plt.close() | 786 | plt.close() |
| @@ -908,9 +905,8 @@ def main(): | |||
| 908 | avg_loss_val.update(loss.detach_(), bsz) | 905 | avg_loss_val.update(loss.detach_(), bsz) |
| 909 | avg_acc_val.update(acc.detach_(), bsz) | 906 | avg_acc_val.update(acc.detach_(), bsz) |
| 910 | 907 | ||
| 911 | if accelerator.sync_gradients: | 908 | local_progress_bar.update(1) |
| 912 | local_progress_bar.update(1) | 909 | global_progress_bar.update(1) |
| 913 | global_progress_bar.update(1) | ||
| 914 | 910 | ||
| 915 | logs = { | 911 | logs = { |
| 916 | "val/loss": avg_loss_val.avg.item(), | 912 | "val/loss": avg_loss_val.avg.item(), |
diff --git a/training/lr.py b/training/lr.py index 5343f24..8e558e1 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -1,3 +1,6 @@ | |||
| 1 | import math | ||
| 2 | import copy | ||
| 3 | |||
| 1 | import matplotlib.pyplot as plt | 4 | import matplotlib.pyplot as plt |
| 2 | import numpy as np | 5 | import numpy as np |
| 3 | import torch | 6 | import torch |
| @@ -16,15 +19,22 @@ class LRFinder(): | |||
| 16 | self.val_dataloader = val_dataloader | 19 | self.val_dataloader = val_dataloader |
| 17 | self.loss_fn = loss_fn | 20 | self.loss_fn = loss_fn |
| 18 | 21 | ||
| 19 | def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): | 22 | self.model_state = copy.deepcopy(model.state_dict()) |
| 23 | self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | ||
| 24 | |||
| 25 | def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | ||
| 20 | best_loss = None | 26 | best_loss = None |
| 21 | lrs = [] | 27 | lrs = [] |
| 22 | losses = [] | 28 | losses = [] |
| 23 | 29 | ||
| 24 | lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) | 30 | lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs) |
| 31 | |||
| 32 | steps = min(num_train_batches, len(self.train_dataloader)) | ||
| 33 | steps += min(num_val_batches, len(self.val_dataloader)) | ||
| 34 | steps *= num_epochs | ||
| 25 | 35 | ||
| 26 | progress_bar = tqdm( | 36 | progress_bar = tqdm( |
| 27 | range(num_epochs * (num_train_steps + num_val_steps)), | 37 | range(steps), |
| 28 | disable=not self.accelerator.is_local_main_process, | 38 | disable=not self.accelerator.is_local_main_process, |
| 29 | dynamic_ncols=True | 39 | dynamic_ncols=True |
| 30 | ) | 40 | ) |
| @@ -38,6 +48,9 @@ class LRFinder(): | |||
| 38 | self.model.train() | 48 | self.model.train() |
| 39 | 49 | ||
| 40 | for step, batch in enumerate(self.train_dataloader): | 50 | for step, batch in enumerate(self.train_dataloader): |
| 51 | if step >= num_train_batches: | ||
| 52 | break | ||
| 53 | |||
| 41 | with self.accelerator.accumulate(self.model): | 54 | with self.accelerator.accumulate(self.model): |
| 42 | loss, acc, bsz = self.loss_fn(batch) | 55 | loss, acc, bsz = self.loss_fn(batch) |
| 43 | 56 | ||
| @@ -49,21 +62,17 @@ class LRFinder(): | |||
| 49 | if self.accelerator.sync_gradients: | 62 | if self.accelerator.sync_gradients: |
| 50 | progress_bar.update(1) | 63 | progress_bar.update(1) |
| 51 | 64 | ||
| 52 | if step >= num_train_steps: | ||
| 53 | break | ||
| 54 | |||
| 55 | self.model.eval() | 65 | self.model.eval() |
| 56 | 66 | ||
| 57 | with torch.inference_mode(): | 67 | with torch.inference_mode(): |
| 58 | for step, batch in enumerate(self.val_dataloader): | 68 | for step, batch in enumerate(self.val_dataloader): |
| 69 | if step >= num_val_batches: | ||
| 70 | break | ||
| 71 | |||
| 59 | loss, acc, bsz = self.loss_fn(batch) | 72 | loss, acc, bsz = self.loss_fn(batch) |
| 60 | avg_loss.update(loss.detach_(), bsz) | 73 | avg_loss.update(loss.detach_(), bsz) |
| 61 | 74 | ||
| 62 | if self.accelerator.sync_gradients: | 75 | progress_bar.update(1) |
| 63 | progress_bar.update(1) | ||
| 64 | |||
| 65 | if step >= num_val_steps: | ||
| 66 | break | ||
| 67 | 76 | ||
| 68 | lr_scheduler.step() | 77 | lr_scheduler.step() |
| 69 | 78 | ||
| @@ -87,6 +96,9 @@ class LRFinder(): | |||
| 87 | "lr": lr, | 96 | "lr": lr, |
| 88 | }) | 97 | }) |
| 89 | 98 | ||
| 99 | self.model.load_state_dict(self.model_state) | ||
| 100 | self.optimizer.load_state_dict(self.optimizer_state) | ||
| 101 | |||
| 90 | if loss > diverge_th * best_loss: | 102 | if loss > diverge_th * best_loss: |
| 91 | print("Stopping early, the loss has diverged") | 103 | print("Stopping early, the loss has diverged") |
| 92 | break | 104 | break |
| @@ -120,8 +132,8 @@ class LRFinder(): | |||
| 120 | ax.set_ylabel("Loss") | 132 | ax.set_ylabel("Loss") |
| 121 | 133 | ||
| 122 | 134 | ||
| 123 | def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): | 135 | def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1): |
| 124 | def lr_lambda(current_epoch: int): | 136 | def lr_lambda(current_epoch: int): |
| 125 | return (current_epoch / num_epochs) ** 5 | 137 | return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr) |
| 126 | 138 | ||
| 127 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 139 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
