summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
committerVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
commit6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch)
tree2ebac26cb0fd377a95437ee54b517011fed36eac /training
parentAdded validation phase to learn rate finder (diff)
downloadtextual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip
Improved learning rate finder
Diffstat (limited to 'training')
-rw-r--r--training/lr.py38
1 file changed, 25 insertions, 13 deletions
diff --git a/training/lr.py b/training/lr.py
index 5343f24..8e558e1 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,3 +1,6 @@
1import math
2import copy
3
1import matplotlib.pyplot as plt 4import matplotlib.pyplot as plt
2import numpy as np 5import numpy as np
3import torch 6import torch
@@ -16,15 +19,22 @@ class LRFinder():
16 self.val_dataloader = val_dataloader 19 self.val_dataloader = val_dataloader
17 self.loss_fn = loss_fn 20 self.loss_fn = loss_fn
18 21
19 def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): 22 self.model_state = copy.deepcopy(model.state_dict())
23 self.optimizer_state = copy.deepcopy(optimizer.state_dict())
24
25 def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
20 best_loss = None 26 best_loss = None
21 lrs = [] 27 lrs = []
22 losses = [] 28 losses = []
23 29
24 lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) 30 lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
31
32 steps = min(num_train_batches, len(self.train_dataloader))
33 steps += min(num_val_batches, len(self.val_dataloader))
34 steps *= num_epochs
25 35
26 progress_bar = tqdm( 36 progress_bar = tqdm(
27 range(num_epochs * (num_train_steps + num_val_steps)), 37 range(steps),
28 disable=not self.accelerator.is_local_main_process, 38 disable=not self.accelerator.is_local_main_process,
29 dynamic_ncols=True 39 dynamic_ncols=True
30 ) 40 )
@@ -38,6 +48,9 @@ class LRFinder():
38 self.model.train() 48 self.model.train()
39 49
40 for step, batch in enumerate(self.train_dataloader): 50 for step, batch in enumerate(self.train_dataloader):
51 if step >= num_train_batches:
52 break
53
41 with self.accelerator.accumulate(self.model): 54 with self.accelerator.accumulate(self.model):
42 loss, acc, bsz = self.loss_fn(batch) 55 loss, acc, bsz = self.loss_fn(batch)
43 56
@@ -49,21 +62,17 @@ class LRFinder():
49 if self.accelerator.sync_gradients: 62 if self.accelerator.sync_gradients:
50 progress_bar.update(1) 63 progress_bar.update(1)
51 64
52 if step >= num_train_steps:
53 break
54
55 self.model.eval() 65 self.model.eval()
56 66
57 with torch.inference_mode(): 67 with torch.inference_mode():
58 for step, batch in enumerate(self.val_dataloader): 68 for step, batch in enumerate(self.val_dataloader):
69 if step >= num_val_batches:
70 break
71
59 loss, acc, bsz = self.loss_fn(batch) 72 loss, acc, bsz = self.loss_fn(batch)
60 avg_loss.update(loss.detach_(), bsz) 73 avg_loss.update(loss.detach_(), bsz)
61 74
62 if self.accelerator.sync_gradients: 75 progress_bar.update(1)
63 progress_bar.update(1)
64
65 if step >= num_val_steps:
66 break
67 76
68 lr_scheduler.step() 77 lr_scheduler.step()
69 78
@@ -87,6 +96,9 @@ class LRFinder():
87 "lr": lr, 96 "lr": lr,
88 }) 97 })
89 98
99 self.model.load_state_dict(self.model_state)
100 self.optimizer.load_state_dict(self.optimizer_state)
101
90 if loss > diverge_th * best_loss: 102 if loss > diverge_th * best_loss:
91 print("Stopping early, the loss has diverged") 103 print("Stopping early, the loss has diverged")
92 break 104 break
@@ -120,8 +132,8 @@ class LRFinder():
120 ax.set_ylabel("Loss") 132 ax.set_ylabel("Loss")
121 133
122 134
def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
    """Return a ``LambdaLR`` scheduler that sweeps the learning rate for an LR-range test.

    The multiplicative factor applied to each param group's base lr grows from
    ``min_lr`` (at epoch 0) to ``1.0`` (at epoch ``num_epochs``) along a steep
    degree-10 polynomial, so most of the sweep is spent near the low end of the
    range — useful for the learning-rate finder's exploratory runs.

    Args:
        optimizer: the wrapped optimizer whose learning rate is swept.
        min_lr: starting multiplier. NOTE(review): ``LambdaLR`` multiplies each
            group's *base* lr by the lambda's return value, so this is a
            relative factor, not an absolute learning rate — confirm callers
            pass it as such.
        num_epochs: number of scheduler steps over which the factor reaches 1.0.
        last_epoch: index of the last epoch, forwarded to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    def lr_lambda(current_epoch: int):
        # Interpolate between min_lr and 1.0; exponent 10 keeps the factor
        # tiny for most of the sweep, approximating an exponential ramp.
        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)

    return LambdaLR(optimizer, lr_lambda, last_epoch)