author     Volpeon <git@volpeon.ink>    2023-01-02 17:34:11 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-02 17:34:11 +0100
commit     67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch)
tree       e1308417bde00609a5347bc39a8cd6583fd066f8 /training
parent     Fix (diff)
Update
Diffstat (limited to 'training')
-rw-r--r--  training/lr.py | 33
1 file changed, 31 insertions(+), 2 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
 import math
 import copy
+from typing import Callable
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
+def noop():
+    pass
+
+
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
+    def __init__(
+        self,
+        accelerator,
+        model,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        loss_fn,
+        on_train: Callable[[], None] = noop,
+        on_eval: Callable[[], None] = noop
+    ):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
+        self.on_train = on_train
+        self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(
+        self,
+        min_lr,
+        skip_start: int = 10,
+        skip_end: int = 5,
+        num_epochs: int = 100,
+        num_train_batches: int = 1,
+        num_val_batches: int = math.inf,
+        smooth_f: float = 0.05,
+        diverge_th: int = 5
+    ):
         best_loss = None
         best_acc = None
 
@@ -50,6 +77,7 @@ class LRFinder():
             avg_acc = AverageMeter()
 
             self.model.train()
+            self.on_train()
 
             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
                 progress_bar.update(1)
 
             self.model.eval()
+            self.on_eval()
 
             with torch.inference_mode():
                 for step, batch in enumerate(self.val_dataloader):
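For context, the commit adds optional on_train/on_eval hooks that default to a no-op, so existing callers of LRFinder keep working unchanged. Below is a minimal, self-contained sketch of that pattern; the Runner class and the printed messages are illustrative placeholders, not code from this repository. Only the noop default and the Callable[[], None] hook signature mirror the diff above.

from typing import Callable


def noop():
    pass


class Runner():
    # Same optional-hook idea as LRFinder.__init__ above: callers that do not
    # need extra behaviour pass nothing, and the defaults silently do nothing.
    def __init__(self, on_train: Callable[[], None] = noop, on_eval: Callable[[], None] = noop):
        self.on_train = on_train
        self.on_eval = on_eval

    def one_epoch(self):
        # LRFinder calls on_train right after self.model.train() ...
        self.on_train()
        # ... training batches would run here ...
        # ... and on_eval right after self.model.eval(), before validation batches.
        self.on_eval()


# Without arguments the hooks are no-ops:
Runner().one_epoch()

# With injected behaviour, e.g. toggling auxiliary modules alongside the model:
Runner(
    on_train=lambda: print("extra modules -> train mode"),
    on_eval=lambda: print("extra modules -> eval mode"),
).one_epoch()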