author     Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
commit     83725794618164210a12843381724252fdd82cc2 (patch)
tree       ec29ade9891fe08dd10b5033214fc09237c2cb86 /training
parent     Improved learning rate finder (diff)
Integrated updates from diffusers
Diffstat (limited to 'training')
-rw-r--r--  training/lr.py    46
-rw-r--r--  training/util.py   5
2 files changed, 36 insertions(+), 15 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
 
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
                     progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-        ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+        ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
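
The new skip_start/skip_end parameters trim the noisy first points and the diverged tail before plotting and before the steepest-gradient suggestion is computed. As a rough standalone illustration of that logic (not code from this repository; the synthetic lrs/losses arrays and the lrs_t/losses_t names below are made up for the example):

import numpy as np

# Synthetic sweep data, for illustration only: loss falls as the learning
# rate grows, then diverges at the high end.
lrs = np.logspace(-6, -1, 100)
losses = 1.0 - 0.8 / (1.0 + np.exp(-3 * (np.log10(lrs) + 4)))
losses[80:] += np.linspace(0.0, 5.0, 20)  # simulated divergence

skip_start, skip_end = 10, 5  # same defaults as the new run() signature

# Trimming mirroring the patch: drop the noisy head and the diverged tail.
if skip_end == 0:
    lrs_t, losses_t = lrs[skip_start:], losses[skip_start:]
else:
    lrs_t, losses_t = lrs[skip_start:-skip_end], losses[skip_start:-skip_end]

# "Steepest gradient" suggestion: the point where the loss falls fastest.
min_grad_idx = np.gradient(np.array(losses_t)).argmin()
print("Suggested LR: {:.2E}".format(lrs_t[min_grad_idx]))
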
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
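
The removed freeze_params helper only disabled gradient tracking on an iterable of parameters. Where that behaviour is still needed, it can be written inline; a minimal sketch follows (the Linear module is a hypothetical placeholder for illustration, not code from this commit):

import torch

model = torch.nn.Linear(4, 2)  # hypothetical placeholder module

# Equivalent of the removed helper: stop tracking gradients for these parameters.
for param in model.parameters():
    param.requires_grad = False

# The same effect via the in-place Module method:
model.requires_grad_(False)
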