From 229bd1d199d1ed2cc61c07a4f34e4a14d208d4f1 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 2 Apr 2023 09:52:05 +0200
Subject: Update

---
 training/functional.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index bd8cbad..b9fb546 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,7 +39,7 @@ class TrainingCallbacks():
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], Any] = const()
     on_after_optimize: Callable[[Any, float], None] = const()
-    on_after_epoch: Callable[[float], None] = const()
+    on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
@@ -496,7 +496,7 @@ def train_loop(
                     "lr": lr,
                 }
                 if isDadaptation:
-                    logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+                    logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]

                 logs.update(on_log())
                 local_progress_bar.set_postfix(**logs)
@@ -528,7 +528,7 @@ def train_loop(

             lrs.append(lr)

-            on_after_epoch(lr)
+            on_after_epoch()

         if val_dataloader is not None:
             model.eval()
-- 
cgit v1.2.3-70-g09d2
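
Note (illustrative sketch, not part of the patch): the commit drops the learning-rate
argument from the on_after_epoch hook and, when a D-Adaptation optimizer is in use,
reassigns the local lr to the effective step size d * lr, so the value appended to lrs
and shown in the progress bar reflects what the optimizer actually applied. The sketch
below shows how a hypothetical caller would adapt to the new zero-argument hook. The
dataclass layout and the const() helper are inferred from the first hunk only; the real
class in training/functional.py has many more fields, and the caller shown here is
invented for illustration.

    from dataclasses import dataclass
    from typing import Any, Callable

    def const(result: Any = None) -> Callable[..., Any]:
        # No-op default in the spirit of the patch's const(): swallow any
        # arguments and return a fixed result.
        def fn(*args: Any, **kwargs: Any) -> Any:
            return result
        return fn

    @dataclass
    class TrainingCallbacks:
        # Reduced to the one field this commit touches (assumed layout).
        on_after_epoch: Callable[[], None] = const()

    # Hypothetical caller. Before this commit the hook received the last
    # learning rate (def on_after_epoch(lr: float) -> None); after it, the
    # hook takes no arguments, since the effective lr (d * lr under
    # D-Adaptation) is already tracked via the lrs list and the "lr/d*lr"
    # log entry.
    def on_after_epoch() -> None:
        print("epoch finished")

    callbacks = TrainingCallbacks(on_after_epoch=on_after_epoch)
    callbacks.on_after_epoch()

Dropping the parameter keeps the hook signature minimal: a consumer that still needs
the per-epoch learning rate can read it from the same place the training loop records
it, rather than receiving a second, possibly stale copy through the callback.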