summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
commit67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch)
treee1308417bde00609a5347bc39a8cd6583fd066f8 /training
parentFix (diff)
downloadtextual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip
Update
Diffstat (limited to 'training')
-rw-r--r--training/lr.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
1import math 1import math
2import copy 2import copy
3from typing import Callable
3 4
4import matplotlib.pyplot as plt 5import matplotlib.pyplot as plt
5import numpy as np 6import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
10from training.util import AverageMeter 11from training.util import AverageMeter
11 12
12 13
14def noop():
15 pass
16
17
13class LRFinder(): 18class LRFinder():
14 def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn): 19 def __init__(
20 self,
21 accelerator,
22 model,
23 optimizer,
24 train_dataloader,
25 val_dataloader,
26 loss_fn,
27 on_train: Callable[[], None] = noop,
28 on_eval: Callable[[], None] = noop
29 ):
15 self.accelerator = accelerator 30 self.accelerator = accelerator
16 self.model = model 31 self.model = model
17 self.optimizer = optimizer 32 self.optimizer = optimizer
18 self.train_dataloader = train_dataloader 33 self.train_dataloader = train_dataloader
19 self.val_dataloader = val_dataloader 34 self.val_dataloader = val_dataloader
20 self.loss_fn = loss_fn 35 self.loss_fn = loss_fn
36 self.on_train = on_train
37 self.on_eval = on_eval
21 38
22 # self.model_state = copy.deepcopy(model.state_dict()) 39 # self.model_state = copy.deepcopy(model.state_dict())
23 # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) 40 # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
24 41
25 def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): 42 def run(
43 self,
44 min_lr,
45 skip_start: int = 10,
46 skip_end: int = 5,
47 num_epochs: int = 100,
48 num_train_batches: int = 1,
49 num_val_batches: int = math.inf,
50 smooth_f: float = 0.05,
51 diverge_th: int = 5
52 ):
26 best_loss = None 53 best_loss = None
27 best_acc = None 54 best_acc = None
28 55
@@ -50,6 +77,7 @@ class LRFinder():
50 avg_acc = AverageMeter() 77 avg_acc = AverageMeter()
51 78
52 self.model.train() 79 self.model.train()
80 self.on_train()
53 81
54 for step, batch in enumerate(self.train_dataloader): 82 for step, batch in enumerate(self.train_dataloader):
55 if step >= num_train_batches: 83 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
67 progress_bar.update(1) 95 progress_bar.update(1)
68 96
69 self.model.eval() 97 self.model.eval()
98 self.on_eval()
70 99
71 with torch.inference_mode(): 100 with torch.inference_mode():
72 for step, batch in enumerate(self.val_dataloader): 101 for step, batch in enumerate(self.val_dataloader):