author    Volpeon <git@volpeon.ink>    2022-12-27 11:48:33 +0100
committer Volpeon <git@volpeon.ink>    2022-12-27 11:48:33 +0100
commit    30098b1d611853c0d3a4687d84582e1c1cf1b938 (patch)
tree      94817d6ccd2fb7a8a58fb8a6ef6543b6db5b9a51 /training
parent    Added learning rate finder (diff)
download  textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.tar.gz
          textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.tar.bz2
          textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.zip
Added validation phase to learning rate finder
Diffstat (limited to 'training')
-rw-r--r--    training/lr.py    34
1 file changed, 23 insertions, 11 deletions
diff --git a/training/lr.py b/training/lr.py
index dd37baa..5343f24 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,20 +1,22 @@
+import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
-import matplotlib.pyplot as plt
 
 from training.util import AverageMeter
 
 
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5):
         best_loss = None
         lrs = []
         losses = []
@@ -22,7 +24,7 @@ class LRFinder():
         lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
 
         progress_bar = tqdm(
-            range(num_epochs * num_steps),
+            range(num_epochs * (num_train_steps + num_val_steps)),
             disable=not self.accelerator.is_local_main_process,
             dynamic_ncols=True
         )
@@ -33,6 +35,8 @@ class LRFinder():
 
             avg_loss = AverageMeter()
 
+            self.model.train()
+
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
@@ -42,13 +46,24 @@ class LRFinder():
                     self.optimizer.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
-                avg_loss.update(loss.detach_(), bsz)
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
 
-                if step >= num_steps:
+                if step >= num_train_steps:
                     break
 
-                if self.accelerator.sync_gradients:
-                    progress_bar.update(1)
+            self.model.eval()
+
+            with torch.inference_mode():
+                for step, batch in enumerate(self.val_dataloader):
+                    loss, acc, bsz = self.loss_fn(batch)
+                    avg_loss.update(loss.detach_(), bsz)
+
+                    if self.accelerator.sync_gradients:
+                        progress_bar.update(1)
+
+                    if step >= num_val_steps:
+                        break
 
             lr_scheduler.step()
 
@@ -104,9 +119,6 @@ class LRFinder():
         ax.set_xlabel("Learning rate")
         ax.set_ylabel("Loss")
 
-        if fig is not None:
-            plt.show()
-
 
 def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
     def lr_lambda(current_epoch: int):
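
Usage note (not part of the commit): with this change LRFinder takes a separate validation dataloader, and run() splits each epoch's budget into num_train_steps training batches and num_val_steps validation batches, so the loss recorded for each learning rate comes from held-out data. Below is a minimal sketch of how the updated class might be driven; the toy model, optimizer, dataloaders, and loss_fn are hypothetical stand-ins, not code from this repository.

# Hypothetical usage sketch of the updated LRFinder; everything built here
# (model, optimizer, data, loss_fn) is a placeholder, not repository code.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

from training.lr import LRFinder

accelerator = Accelerator()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

train_data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
val_data = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=8)

model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader
)

def loss_fn(batch):
    # The finder expects its callback to return (loss, accuracy, batch_size).
    inputs, targets = batch
    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    acc = (logits.argmax(dim=-1) == targets).float().mean()
    return loss, acc, targets.shape[0]

finder = LRFinder(accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn)
# One training batch and one validation batch per epoch while the exponential
# schedule sweeps the learning rate upward.
finder.run(num_epochs=100, num_train_steps=1, num_val_steps=1)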