diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
1 | from dataclasses import dataclass | 1 | from dataclasses import dataclass |
2 | import math | 2 | import math |
3 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext |
4 | from typing import Callable, Any, Tuple, Union, Optional, Type | 4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol |
5 | from functools import partial | 5 | from functools import partial |
6 | from pathlib import Path | 6 | from pathlib import Path |
7 | import itertools | 7 | import itertools |
@@ -37,7 +37,7 @@ class TrainingCallbacks():
37 | on_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_model: Callable[[], torch.nn.Module] = const(None) |
38 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[int], None] = const() | 40 | on_before_optimize: Callable[[float, int], None] = const() |
41 | on_after_optimize: Callable[[float], None] = const() | 41 | on_after_optimize: Callable[[float], None] = const() |
42 | on_after_epoch: Callable[[float], None] = const() | 42 | on_after_epoch: Callable[[float], None] = const() |
43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
@@ -331,13 +331,17 @@ def loss_step(
331 | return loss, acc, bsz | 331 | return loss, acc, bsz |
332 | 332 | ||
333 | 333 | ||
334 | class LossCallable(Protocol): | ||
335 | def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ... | ||
336 | |||
337 | |||
334 | def train_loop( | 338 | def train_loop( |
335 | accelerator: Accelerator, | 339 | accelerator: Accelerator, |
336 | optimizer: torch.optim.Optimizer, | 340 | optimizer: torch.optim.Optimizer, |
337 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 341 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
338 | train_dataloader: DataLoader, | 342 | train_dataloader: DataLoader, |
339 | val_dataloader: Optional[DataLoader], | 343 | val_dataloader: Optional[DataLoader], |
340 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 344 | loss_step: LossCallable, |
341 | sample_frequency: int = 10, | 345 | sample_frequency: int = 10, |
342 | checkpoint_frequency: int = 50, | 346 | checkpoint_frequency: int = 50, |
343 | global_step_offset: int = 0, | 347 | global_step_offset: int = 0, |
@@ -406,7 +410,7 @@ def train_loop(
406 | 410 | ||
407 | accelerator.backward(loss) | 411 | accelerator.backward(loss) |
408 | 412 | ||
409 | on_before_optimize(epoch) | 413 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) |
410 | 414 | ||
411 | optimizer.step() | 415 | optimizer.step() |
412 | lr_scheduler.step() | 416 | lr_scheduler.step() |