| author | Volpeon <git@volpeon.ink> | 2023-01-17 16:39:33 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-17 16:39:33 +0100 |
| commit | 8e9d62225db11913bf7ef67221fc3508d7fe1149 | |
| tree | 4c17e8491a77bc92deb276dedba7949a8bb7297a /training/functional.py | |
| parent | Optimized embedding normalization | |
Update
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -1,7 +1,7 @@ | |||
| 1 | from dataclasses import dataclass | 1 | from dataclasses import dataclass |
| 2 | import math | 2 | import math |
| 3 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext |
| 4 | from typing import Callable, Any, Tuple, Union, Optional, Type | 4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol |
| 5 | from functools import partial | 5 | from functools import partial |
| 6 | from pathlib import Path | 6 | from pathlib import Path |
| 7 | import itertools | 7 | import itertools |
| @@ -37,7 +37,7 @@ class TrainingCallbacks(): | |||
| 37 | on_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_model: Callable[[], torch.nn.Module] = const(None) |
| 38 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[int], None] = const() | 40 | on_before_optimize: Callable[[float, int], None] = const() |
| 41 | on_after_optimize: Callable[[float], None] = const() | 41 | on_after_optimize: Callable[[float], None] = const() |
| 42 | on_after_epoch: Callable[[float], None] = const() | 42 | on_after_epoch: Callable[[float], None] = const() |
| 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
| @@ -331,13 +331,17 @@ def loss_step( | |||
| 331 | return loss, acc, bsz | 331 | return loss, acc, bsz |
| 332 | | 332 | |
| 333 | | 333 | |
| | | 334 | class LossCallable(Protocol): |
| | | 335 | def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ... |
| | | 336 | |
| | | 337 | |
| 334 | def train_loop( | 338 | def train_loop( |
| 335 | accelerator: Accelerator, | 339 | accelerator: Accelerator, |
| 336 | optimizer: torch.optim.Optimizer, | 340 | optimizer: torch.optim.Optimizer, |
| 337 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 341 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 338 | train_dataloader: DataLoader, | 342 | train_dataloader: DataLoader, |
| 339 | val_dataloader: Optional[DataLoader], | 343 | val_dataloader: Optional[DataLoader], |
| 340 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 344 | loss_step: LossCallable, |
| 341 | sample_frequency: int = 10, | 345 | sample_frequency: int = 10, |
| 342 | checkpoint_frequency: int = 50, | 346 | checkpoint_frequency: int = 50, |
| 343 | global_step_offset: int = 0, | 347 | global_step_offset: int = 0, |
| @@ -406,7 +410,7 @@ def train_loop( | |||
| 406 | | 410 | |
| 407 | accelerator.backward(loss) | 411 | accelerator.backward(loss) |
| 408 | | 412 | |
| 409 | on_before_optimize(epoch) | 413 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) |
| 410 | | 414 | |
| 411 | optimizer.step() | 415 | optimizer.step() |
| 412 | lr_scheduler.step() | 416 | lr_scheduler.step() |
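
The signature change to `on_before_optimize` threads the current learning rate through to the callback, taken from `lr_scheduler.get_last_lr()[0]` at the call site, so a hook can see the value the optimizer is about to step with. A minimal sketch of a callback matching the new `(lr, epoch)` signature; the logging body and the wiring into `TrainingCallbacks` are illustrative, not part of this commit:

```python
from training.functional import TrainingCallbacks

def log_lr_before_step(lr: float, epoch: int) -> None:
    # Runs after accelerator.backward(loss) and before optimizer.step();
    # `lr` is the scheduler's most recently computed learning rate.
    print(f"epoch {epoch}: stepping with lr={lr:.3e}")

callbacks = TrainingCallbacks(on_before_optimize=log_lr_before_step)
```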

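The new `LossCallable` protocol replaces a `Union` of two `Callable` annotations, presumably because `Callable` parameters are positional-only and cannot express the optional `eval` keyword of `loss_step`. Since `Protocol` checks structurally, any function whose parameter names, kinds, and defaults line up satisfies the annotation. A hypothetical sketch (the `dummy_loss_step` body and the `input_ids` batch key are placeholders):

```python
from typing import Any, Protocol, Tuple

import torch

class LossCallable(Protocol):
    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...

def dummy_loss_step(step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]:
    # Mirrors the (loss, accuracy, batch size) contract of loss_step.
    bsz = len(batch["input_ids"])
    return torch.tensor(0.0), torch.tensor(0.0), bsz

fn: LossCallable = dummy_loss_step  # accepted by a static type checker
```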