author     Volpeon <git@volpeon.ink>  2023-01-15 12:33:52 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-15 12:33:52 +0100
commit     59bf501198d7ff6c0c03c45e92adef14069d5ac6 (patch)
tree       aae4c7204b4f04bf2146408fb88892071840a05d /training/lr.py
parent     Removed unused code, put training callbacks in dataclass (diff)
Update
Diffstat (limited to 'training/lr.py')
-rw-r--r--  training/lr.py | 29 +++++++++++++++--------------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
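
The commit collapses the per-hook keyword arguments into a single TrainingCallbacks value, and the model is now obtained through callbacks.on_model() instead of being passed in directly. TrainingCallbacks itself lives in training/functional.py, which is not part of this diff, so the following is only a minimal sketch of the interface the call sites above imply; the field types and defaults are assumptions inferred from lr.py, not the repository's actual definitions.

# Hypothetical sketch of training/functional.py's TrainingCallbacks,
# inferred from the call sites in lr.py above; not the actual definition.
from contextlib import _GeneratorContextManager, contextmanager
from dataclasses import dataclass
from typing import Any, Callable


def noop(*args, **kwargs):
    # Mirrors the `noop` default the removed keyword arguments used.
    pass


@contextmanager
def noop_ctx(*args, **kwargs):
    # Mirrors the removed `noop_ctx` default: a do-nothing context manager.
    yield


@dataclass
class TrainingCallbacks:
    on_model: Callable[[], Any] = noop                  # returns the model; called once in __init__
    on_prepare: Callable[[], None] = noop               # one-time setup before the epoch loop
    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx  # wraps each training epoch
    on_before_optimize: Callable[[int], None] = noop    # runs just before optimizer.step()
    on_after_optimize: Callable[[float], None] = noop   # receives the last learning rate
    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx      # wraps validation

Under these assumptions, a caller would construct the finder roughly like so, with on_model standing in for the removed model argument (hypothetical usage, not from the repository):

finder = LRFinder(
    accelerator=accelerator,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    callbacks=TrainingCallbacks(on_model=lambda: model),
)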