author    Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
committer Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
commit    83725794618164210a12843381724252fdd82cc2 (patch)
tree      ec29ade9891fe08dd10b5033214fc09237c2cb86 /training
parent    Improved learning rate finder (diff)
Integrated updates from diffusers
Diffstat (limited to 'training')
-rw-r--r--  training/lr.py   | 46
-rw-r--r--  training/util.py |  5
2 files changed, 36 insertions, 15 deletions
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
 
                 loss, acc, bsz = self.loss_fn(batch)
                 avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
                 progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-        ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+        ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
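For context, the suggestion logic the diff keeps is the usual "steepest gradient of the loss curve" heuristic: after trimming skip_start/skip_end points, the learning rate at the most negative numerical gradient of the recorded losses is reported. The sketch below illustrates that rule in isolation; the lrs/losses arrays are synthetic placeholder data, not taken from this repository.

import numpy as np

# Synthetic LR range-test data (assumed shape only): exponentially spaced
# learning rates and a roughly U-shaped loss curve with a little noise.
rng = np.random.default_rng(0)
lrs = np.geomspace(1e-6, 1e-1, 100)
losses = (np.log10(lrs) + 3.0) ** 2 / 10.0 + 0.02 * rng.random(100)

# Trim the unstable start and the diverging tail, as run() does.
skip_start, skip_end = 10, 5
lrs = lrs[skip_start:-skip_end] if skip_end > 0 else lrs[skip_start:]
losses = losses[skip_start:-skip_end] if skip_end > 0 else losses[skip_start:]

# Steepest gradient: the index where the loss falls fastest as the LR grows.
min_grad_idx = np.gradient(np.array(losses)).argmin()
print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))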
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
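The deleted freeze_params() helper has a one-line equivalent in plain PyTorch, which is presumably how the training code freezes modules after this change (an assumption about the motivation; this diff only shows the deletion). A small sketch of both forms:

import torch

# Hypothetical stand-in for a module that should stay frozen during training.
module = torch.nn.Linear(4, 4)

# Standard PyTorch: freeze every parameter of a module in one call.
module.requires_grad_(False)

# Per-parameter form, equivalent to what the removed freeze_params() did.
for param in module.parameters():
    param.requires_grad = False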