summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py6
1 files changed, 0 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index c30d1c0..4d83df1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
34 34
35@dataclass 35@dataclass
36class TrainingCallbacks(): 36class TrainingCallbacks():
37 on_accum_model: Callable[[], torch.nn.Module] = const(None)
38 on_log: Callable[[], dict[str, Any]] = const({}) 37 on_log: Callable[[], dict[str, Any]] = const({})
39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
40 on_before_optimize: Callable[[float, int], Any] = const() 39 on_before_optimize: Callable[[float, int], Any] = const()
@@ -461,7 +460,6 @@ def train_loop(
461 ) 460 )
462 global_progress_bar.set_description("Total progress") 461 global_progress_bar.set_description("Total progress")
463 462
464 model = callbacks.on_accum_model()
465 on_log = callbacks.on_log 463 on_log = callbacks.on_log
466 on_train = callbacks.on_train 464 on_train = callbacks.on_train
467 on_before_optimize = callbacks.on_before_optimize 465 on_before_optimize = callbacks.on_before_optimize
@@ -498,8 +496,6 @@ def train_loop(
498 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 496 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
499 local_progress_bar.reset() 497 local_progress_bar.reset()
500 498
501 model.train()
502
503 with on_train(epoch): 499 with on_train(epoch):
504 for step, batch in enumerate(train_dataloader): 500 for step, batch in enumerate(train_dataloader):
505 loss, acc, bsz = loss_step(step, batch, cache) 501 loss, acc, bsz = loss_step(step, batch, cache)
@@ -560,8 +556,6 @@ def train_loop(
560 on_after_epoch() 556 on_after_epoch()
561 557
562 if val_dataloader is not None: 558 if val_dataloader is not None:
563 model.eval()
564
565 cur_loss_val = AverageMeter() 559 cur_loss_val = AverageMeter()
566 cur_acc_val = AverageMeter() 560 cur_acc_val = AverageMeter()
567 561