summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-02 09:52:05 +0200
committerVolpeon <git@volpeon.ink>2023-04-02 09:52:05 +0200
commit229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1 (patch)
tree6ff32d5bc2a1acdab36b17c8b9175545f6d5bfe0
parentLora: Only register params with grad to optimizer (diff)
downloadtextual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.tar.gz
textual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.tar.bz2
textual-inversion-diff-229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1.zip
Update
-rw-r--r--training/functional.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index bd8cbad..b9fb546 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,7 +39,7 @@ class TrainingCallbacks():
39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
40 on_before_optimize: Callable[[float, int], Any] = const() 40 on_before_optimize: Callable[[float, int], Any] = const()
41 on_after_optimize: Callable[[Any, float], None] = const() 41 on_after_optimize: Callable[[Any, float], None] = const()
42 on_after_epoch: Callable[[float], None] = const() 42 on_after_epoch: Callable[[], None] = const()
43 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) 43 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
44 on_sample: Callable[[int], None] = const() 44 on_sample: Callable[[int], None] = const()
45 on_checkpoint: Callable[[int, str], None] = const() 45 on_checkpoint: Callable[[int, str], None] = const()
@@ -496,7 +496,7 @@ def train_loop(
496 "lr": lr, 496 "lr": lr,
497 } 497 }
498 if isDadaptation: 498 if isDadaptation:
499 logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] 499 logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
500 logs.update(on_log()) 500 logs.update(on_log())
501 501
502 local_progress_bar.set_postfix(**logs) 502 local_progress_bar.set_postfix(**logs)
@@ -528,7 +528,7 @@ def train_loop(
528 528
529 lrs.append(lr) 529 lrs.append(lr)
530 530
531 on_after_epoch(lr) 531 on_after_epoch()
532 532
533 if val_dataloader is not None: 533 if val_dataloader is not None:
534 model.eval() 534 model.eval()