From 8e9d62225db11913bf7ef67221fc3508d7fe1149 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 16:39:33 +0100
Subject: Update

---
 training/functional.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional, Type
+from typing import Callable, Any, Tuple, Union, Optional, Protocol
 from functools import partial
 from pathlib import Path
 import itertools
@@ -37,7 +37,7 @@ class TrainingCallbacks():
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[int], None] = const()
+    on_before_optimize: Callable[[float, int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
@@ -331,13 +331,17 @@ def loss_step(
     return loss, acc, bsz
 
 
+class LossCallable(Protocol):
+    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+
+
 def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,7 +410,7 @@ def train_loop(
 
         accelerator.backward(loss)
 
-        on_before_optimize(epoch)
+        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
         optimizer.step()
         lr_scheduler.step()
--
cgit v1.2.3-54-g00ecf
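
Editor's note on the typing change: typing.Callable cannot express keyword
arguments or default values, which is why the old annotation needed a Union of
two Callable signatures to cover calls with and without the eval flag. The
patch replaces that Union with a Protocol whose __call__ declares the optional
keyword directly. A minimal, self-contained sketch of the idea (my_loss_step
and run are illustrative names, not code from this repository):

# Standalone sketch: why a Protocol replaces the Union of Callables.
# Callable[[int, Any], ...] cannot encode a keyword argument with a
# default, but a Protocol's __call__ can.
from typing import Any, Protocol, Tuple


class LossCallable(Protocol):
    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...


def my_loss_step(step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]:
    # Dummy stand-in for the real loss computation; returns
    # (loss, accuracy, batch size), matching loss_step in the patch.
    return 0.0, 0.0, 1


def run(fn: LossCallable) -> None:
    fn(0, {})             # eval defaults to False (training step)
    fn(0, {}, eval=True)  # validation step; type-checks either way


run(my_loss_step)  # satisfies the protocol structurally, no subclassing needed

Protocols match structurally, so existing loss functions need no changes to
satisfy the new annotation; the type checker simply gains visibility into the
eval keyword that the Union version could not describe.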
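
The runtime change gives on_before_optimize the current learning rate ahead of
the epoch index, read via get_last_lr(), which returns one value per parameter
group (hence the [0]). A hedged sketch of the new contract; the SGD and
ExponentialLR pairing below is an arbitrary placeholder, not necessarily what
the training scripts use:

# Sketch of the new on_before_optimize(lr, epoch) callback contract.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


def on_before_optimize(lr: float, epoch: int) -> None:
    # Callbacks can now log or act on the live learning rate.
    print(f"epoch {epoch}: lr = {lr:.6f}")


for epoch in range(3):
    # get_last_lr() returns a list with one entry per param group;
    # [0] picks the first group, as in the patch.
    on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
    optimizer.step()
    lr_scheduler.step()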