diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 6 |
1 files changed, 0 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index c30d1c0..4d83df1 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -34,7 +34,6 @@ def const(result=None): | |||
34 | 34 | ||
35 | @dataclass | 35 | @dataclass |
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | ||
38 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[float, int], Any] = const() |
@@ -461,7 +460,6 @@ def train_loop( | |||
461 | ) | 460 | ) |
462 | global_progress_bar.set_description("Total progress") | 461 | global_progress_bar.set_description("Total progress") |
463 | 462 | ||
464 | model = callbacks.on_accum_model() | ||
465 | on_log = callbacks.on_log | 463 | on_log = callbacks.on_log |
466 | on_train = callbacks.on_train | 464 | on_train = callbacks.on_train |
467 | on_before_optimize = callbacks.on_before_optimize | 465 | on_before_optimize = callbacks.on_before_optimize |
@@ -498,8 +496,6 @@ def train_loop( | |||
498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
499 | local_progress_bar.reset() | 497 | local_progress_bar.reset() |
500 | 498 | ||
501 | model.train() | ||
502 | |||
503 | with on_train(epoch): | 499 | with on_train(epoch): |
504 | for step, batch in enumerate(train_dataloader): | 500 | for step, batch in enumerate(train_dataloader): |
505 | loss, acc, bsz = loss_step(step, batch, cache) | 501 | loss, acc, bsz = loss_step(step, batch, cache) |
@@ -560,8 +556,6 @@ def train_loop( | |||
560 | on_after_epoch() | 556 | on_after_epoch() |
561 | 557 | ||
562 | if val_dataloader is not None: | 558 | if val_dataloader is not None: |
563 | model.eval() | ||
564 | |||
565 | cur_loss_val = AverageMeter() | 559 | cur_loss_val = AverageMeter() |
566 | cur_acc_val = AverageMeter() | 560 | cur_acc_val = AverageMeter() |
567 | 561 | ||