diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 6 |
1 files changed, 0 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index c30d1c0..4d83df1 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -34,7 +34,6 @@ def const(result=None): | |||
| 34 | 34 | ||
| 35 | @dataclass | 35 | @dataclass |
| 36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
| 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | ||
| 38 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[float, int], Any] = const() |
| @@ -461,7 +460,6 @@ def train_loop( | |||
| 461 | ) | 460 | ) |
| 462 | global_progress_bar.set_description("Total progress") | 461 | global_progress_bar.set_description("Total progress") |
| 463 | 462 | ||
| 464 | model = callbacks.on_accum_model() | ||
| 465 | on_log = callbacks.on_log | 463 | on_log = callbacks.on_log |
| 466 | on_train = callbacks.on_train | 464 | on_train = callbacks.on_train |
| 467 | on_before_optimize = callbacks.on_before_optimize | 465 | on_before_optimize = callbacks.on_before_optimize |
| @@ -498,8 +496,6 @@ def train_loop( | |||
| 498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 499 | local_progress_bar.reset() | 497 | local_progress_bar.reset() |
| 500 | 498 | ||
| 501 | model.train() | ||
| 502 | |||
| 503 | with on_train(epoch): | 499 | with on_train(epoch): |
| 504 | for step, batch in enumerate(train_dataloader): | 500 | for step, batch in enumerate(train_dataloader): |
| 505 | loss, acc, bsz = loss_step(step, batch, cache) | 501 | loss, acc, bsz = loss_step(step, batch, cache) |
| @@ -560,8 +556,6 @@ def train_loop( | |||
| 560 | on_after_epoch() | 556 | on_after_epoch() |
| 561 | 557 | ||
| 562 | if val_dataloader is not None: | 558 | if val_dataloader is not None: |
| 563 | model.eval() | ||
| 564 | |||
| 565 | cur_loss_val = AverageMeter() | 559 | cur_loss_val = AverageMeter() |
| 566 | cur_acc_val = AverageMeter() | 560 | cur_acc_val = AverageMeter() |
| 567 | 561 | ||
