author     Volpeon <git@volpeon.ink>  2022-12-27 13:58:48 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-27 13:58:48 +0100
commit     6df1fc46daca9c289f1d7f7524e01deac5c92fd1
tree       2ebac26cb0fd377a95437ee54b517011fed36eac
parent     Added validation phase to learn rate finder
Improved learning rate finder
-rw-r--r--  train_dreambooth.py  |  5
-rw-r--r--  train_ti.py          | 10
-rw-r--r--  training/lr.py       | 38
3 files changed, 30 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a62cec9..325fe90 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -970,9 +970,8 @@ def main():
                     avg_loss_val.update(loss.detach_(), bsz)
                     avg_acc_val.update(acc.detach_(), bsz)
 
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                     logs = {
                         "val/loss": avg_loss_val.avg.item(),
diff --git a/train_ti.py b/train_ti.py
index 32f44f4..870b2ba 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -548,9 +548,6 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e2
-
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -783,7 +780,7 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(num_train_steps=2)
+        lr_finder.run(min_lr=1e-6, num_train_batches=4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
@@ -908,9 +905,8 @@ def main():
                     avg_loss_val.update(loss.detach_(), bsz)
                     avg_acc_val.update(acc.detach_(), bsz)
 
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                     logs = {
                         "val/loss": avg_loss_val.avg.item(),
diff --git a/training/lr.py b/training/lr.py
index 5343f24..8e558e1 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,3 +1,6 @@
+import math
+import copy
+
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -16,15 +19,22 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5):
+        self.model_state = copy.deepcopy(model.state_dict())
+        self.optimizer_state = copy.deepcopy(optimizer.state_dict())
+
+    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
         lrs = []
         losses = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+
+        steps = min(num_train_batches, len(self.train_dataloader))
+        steps += min(num_val_batches, len(self.val_dataloader))
+        steps *= num_epochs
 
         progress_bar = tqdm(
-            range(num_epochs * (num_train_steps + num_val_steps)),
+            range(steps),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True
        )
@@ -38,6 +48,9 @@ class LRFinder():
             self.model.train()
 
             for step, batch in enumerate(self.train_dataloader):
+                if step >= num_train_batches:
+                    break
+
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
 
@@ -49,21 +62,17 @@ class LRFinder():
                 if self.accelerator.sync_gradients:
                     progress_bar.update(1)
 
-                if step >= num_train_steps:
-                    break
-
             self.model.eval()
 
             with torch.inference_mode():
                 for step, batch in enumerate(self.val_dataloader):
+                    if step >= num_val_batches:
+                        break
+
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
 
-                    if self.accelerator.sync_gradients:
-                        progress_bar.update(1)
-
-                    if step >= num_val_steps:
-                        break
+                    progress_bar.update(1)
 
             lr_scheduler.step()
 
@@ -87,6 +96,9 @@ class LRFinder():
87 "lr": lr, 96 "lr": lr,
88 }) 97 })
89 98
99 self.model.load_state_dict(self.model_state)
100 self.optimizer.load_state_dict(self.optimizer_state)
101
90 if loss > diverge_th * best_loss: 102 if loss > diverge_th * best_loss:
91 print("Stopping early, the loss has diverged") 103 print("Stopping early, the loss has diverged")
92 break 104 break
@@ -120,8 +132,8 @@ class LRFinder():
         ax.set_ylabel("Loss")
 
 
-def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
+def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
     def lr_lambda(current_epoch: int):
-        return (current_epoch / num_epochs) ** 5
+        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
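
The reworked schedule is a LambdaLR factor, so the value returned by lr_lambda multiplies the optimizer's base learning rate; a minimal sketch (not part of the patch, with an arbitrary base_lr) of how the sweep runs from base_lr * min_lr up to base_lr over num_epochs:

import torch
from torch.optim.lr_scheduler import LambdaLR

def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
    def lr_lambda(current_epoch: int):
        # min_lr at epoch 0, 1.0 at epoch num_epochs; the ** 10 keeps the sweep
        # near the low end for most of the run and ramps up sharply at the end.
        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
    return LambdaLR(optimizer, lr_lambda, last_epoch)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-2)   # arbitrary base_lr for the sketch
scheduler = get_exponential_schedule(optimizer, min_lr=1e-6, num_epochs=100)

for epoch in range(100):
    optimizer.step()       # stand-in for one finder epoch of train/val batches
    scheduler.step()
    # factor at epoch 0:   1e-6     -> lr = 1e-8
    # factor at epoch 50:  ~9.8e-4  -> lr ≈ 9.8e-6
    # factor at epoch 100: 1.0      -> lr = 1e-2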