diff options
author | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
commit | 6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch) | |
tree | 2ebac26cb0fd377a95437ee54b517011fed36eac /training | |
parent | Added validation phase to learn rate finder (diff) | |
download | textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2 textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip |
Improved learning rate finder
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/training/lr.py b/training/lr.py index 5343f24..8e558e1 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -1,3 +1,6 @@ | |||
1 | import math | ||
2 | import copy | ||
3 | |||
1 | import matplotlib.pyplot as plt | 4 | import matplotlib.pyplot as plt |
2 | import numpy as np | 5 | import numpy as np |
3 | import torch | 6 | import torch |
@@ -16,15 +19,22 @@ class LRFinder(): | |||
16 | self.val_dataloader = val_dataloader | 19 | self.val_dataloader = val_dataloader |
17 | self.loss_fn = loss_fn | 20 | self.loss_fn = loss_fn |
18 | 21 | ||
19 | def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): | 22 | self.model_state = copy.deepcopy(model.state_dict()) |
23 | self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | ||
24 | |||
25 | def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | ||
20 | best_loss = None | 26 | best_loss = None |
21 | lrs = [] | 27 | lrs = [] |
22 | losses = [] | 28 | losses = [] |
23 | 29 | ||
24 | lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) | 30 | lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs) |
31 | |||
32 | steps = min(num_train_batches, len(self.train_dataloader)) | ||
33 | steps += min(num_val_batches, len(self.val_dataloader)) | ||
34 | steps *= num_epochs | ||
25 | 35 | ||
26 | progress_bar = tqdm( | 36 | progress_bar = tqdm( |
27 | range(num_epochs * (num_train_steps + num_val_steps)), | 37 | range(steps), |
28 | disable=not self.accelerator.is_local_main_process, | 38 | disable=not self.accelerator.is_local_main_process, |
29 | dynamic_ncols=True | 39 | dynamic_ncols=True |
30 | ) | 40 | ) |
@@ -38,6 +48,9 @@ class LRFinder(): | |||
38 | self.model.train() | 48 | self.model.train() |
39 | 49 | ||
40 | for step, batch in enumerate(self.train_dataloader): | 50 | for step, batch in enumerate(self.train_dataloader): |
51 | if step >= num_train_batches: | ||
52 | break | ||
53 | |||
41 | with self.accelerator.accumulate(self.model): | 54 | with self.accelerator.accumulate(self.model): |
42 | loss, acc, bsz = self.loss_fn(batch) | 55 | loss, acc, bsz = self.loss_fn(batch) |
43 | 56 | ||
@@ -49,21 +62,17 @@ class LRFinder(): | |||
49 | if self.accelerator.sync_gradients: | 62 | if self.accelerator.sync_gradients: |
50 | progress_bar.update(1) | 63 | progress_bar.update(1) |
51 | 64 | ||
52 | if step >= num_train_steps: | ||
53 | break | ||
54 | |||
55 | self.model.eval() | 65 | self.model.eval() |
56 | 66 | ||
57 | with torch.inference_mode(): | 67 | with torch.inference_mode(): |
58 | for step, batch in enumerate(self.val_dataloader): | 68 | for step, batch in enumerate(self.val_dataloader): |
69 | if step >= num_val_batches: | ||
70 | break | ||
71 | |||
59 | loss, acc, bsz = self.loss_fn(batch) | 72 | loss, acc, bsz = self.loss_fn(batch) |
60 | avg_loss.update(loss.detach_(), bsz) | 73 | avg_loss.update(loss.detach_(), bsz) |
61 | 74 | ||
62 | if self.accelerator.sync_gradients: | 75 | progress_bar.update(1) |
63 | progress_bar.update(1) | ||
64 | |||
65 | if step >= num_val_steps: | ||
66 | break | ||
67 | 76 | ||
68 | lr_scheduler.step() | 77 | lr_scheduler.step() |
69 | 78 | ||
@@ -87,6 +96,9 @@ class LRFinder(): | |||
87 | "lr": lr, | 96 | "lr": lr, |
88 | }) | 97 | }) |
89 | 98 | ||
99 | self.model.load_state_dict(self.model_state) | ||
100 | self.optimizer.load_state_dict(self.optimizer_state) | ||
101 | |||
90 | if loss > diverge_th * best_loss: | 102 | if loss > diverge_th * best_loss: |
91 | print("Stopping early, the loss has diverged") | 103 | print("Stopping early, the loss has diverged") |
92 | break | 104 | break |
@@ -120,8 +132,8 @@ class LRFinder(): | |||
120 | ax.set_ylabel("Loss") | 132 | ax.set_ylabel("Loss") |
121 | 133 | ||
122 | 134 | ||
123 | def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): | 135 | def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1): |
124 | def lr_lambda(current_epoch: int): | 136 | def lr_lambda(current_epoch: int): |
125 | return (current_epoch / num_epochs) ** 5 | 137 | return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr) |
126 | 138 | ||
127 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 139 | return LambdaLR(optimizer, lr_lambda, last_epoch) |