diff options
author | Volpeon <git@volpeon.ink> | 2023-04-02 09:52:05 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-02 09:52:05 +0200 |
commit | 229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1 (patch) | |
tree | 6ff32d5bc2a1acdab36b17c8b9175545f6d5bfe0 | |
parent | Lora: Only register params with grad to optimizer (diff) | |
download | textual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.tar.gz textual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.tar.bz2 textual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.zip |
Update
-rw-r--r-- | training/functional.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index bd8cbad..b9fb546 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -39,7 +39,7 @@ class TrainingCallbacks(): | |||
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[float, int], Any] = const() | 40 | on_before_optimize: Callable[[float, int], Any] = const() |
41 | on_after_optimize: Callable[[Any, float], None] = const() | 41 | on_after_optimize: Callable[[Any, float], None] = const() |
42 | on_after_epoch: Callable[[float], None] = const() | 42 | on_after_epoch: Callable[[], None] = const() |
43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
44 | on_sample: Callable[[int], None] = const() | 44 | on_sample: Callable[[int], None] = const() |
45 | on_checkpoint: Callable[[int, str], None] = const() | 45 | on_checkpoint: Callable[[int, str], None] = const() |
@@ -496,7 +496,7 @@ def train_loop( | |||
496 | "lr": lr, | 496 | "lr": lr, |
497 | } | 497 | } |
498 | if isDadaptation: | 498 | if isDadaptation: |
499 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] | 499 | logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] |
500 | logs.update(on_log()) | 500 | logs.update(on_log()) |
501 | 501 | ||
502 | local_progress_bar.set_postfix(**logs) | 502 | local_progress_bar.set_postfix(**logs) |
@@ -528,7 +528,7 @@ def train_loop( | |||
528 | 528 | ||
529 | lrs.append(lr) | 529 | lrs.append(lr) |
530 | 530 | ||
531 | on_after_epoch(lr) | 531 | on_after_epoch() |
532 | 532 | ||
533 | if val_dataloader is not None: | 533 | if val_dataloader is not None: |
534 | model.eval() | 534 | model.eval() |