diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
commit | a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 (patch) | |
tree | 8bd97a745e1113b1035c504ec484e099f878aed0 /training | |
parent | Various updates (diff) | |
download | textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.gz textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.bz2 textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.zip |
Update
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 7 |
1 file changed, 6 insertions, 1 deletion
diff --git a/training/lr.py b/training/lr.py index c8dc040..3cdf994 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -12,7 +12,7 @@ from tqdm.auto import tqdm | |||
12 | from training.util import AverageMeter | 12 | from training.util import AverageMeter |
13 | 13 | ||
14 | 14 | ||
15 | def noop(): | 15 | def noop(*args, **kwards): |
16 | pass | 16 | pass |
17 | 17 | ||
18 | 18 | ||
@@ -26,6 +26,7 @@ class LRFinder(): | |||
26 | val_dataloader, | 26 | val_dataloader, |
27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], None] = noop, |
29 | on_clip: Callable[[], None] = noop, | ||
29 | on_eval: Callable[[], None] = noop | 30 | on_eval: Callable[[], None] = noop |
30 | ): | 31 | ): |
31 | self.accelerator = accelerator | 32 | self.accelerator = accelerator |
@@ -35,6 +36,7 @@ class LRFinder(): | |||
35 | self.val_dataloader = val_dataloader | 36 | self.val_dataloader = val_dataloader |
36 | self.loss_fn = loss_fn | 37 | self.loss_fn = loss_fn |
37 | self.on_train = on_train | 38 | self.on_train = on_train |
39 | self.on_clip = on_clip | ||
38 | self.on_eval = on_eval | 40 | self.on_eval = on_eval |
39 | 41 | ||
40 | # self.model_state = copy.deepcopy(model.state_dict()) | 42 | # self.model_state = copy.deepcopy(model.state_dict()) |
@@ -93,6 +95,9 @@ class LRFinder(): | |||
93 | 95 | ||
94 | self.accelerator.backward(loss) | 96 | self.accelerator.backward(loss) |
95 | 97 | ||
98 | if self.accelerator.sync_gradients: | ||
99 | self.on_clip() | ||
100 | |||
96 | self.optimizer.step() | 101 | self.optimizer.step() |
97 | lr_scheduler.step() | 102 | lr_scheduler.step() |
98 | self.optimizer.zero_grad(set_to_none=True) | 103 | self.optimizer.zero_grad(set_to_none=True) |