path: root/training/functional.py
author    Volpeon <git@volpeon.ink>  2023-01-17 16:39:33 +0100
committer Volpeon <git@volpeon.ink>  2023-01-17 16:39:33 +0100
commit    8e9d62225db11913bf7ef67221fc3508d7fe1149 (patch)
tree      4c17e8491a77bc92deb276dedba7949a8bb7297a /training/functional.py
parent    Optimized embedding normalization (diff)
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  12
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional, Type
+from typing import Callable, Any, Tuple, Union, Optional, Protocol
 from functools import partial
 from pathlib import Path
 import itertools
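
The import change above (Type is dropped in favor of Protocol) is what makes the new LossCallable type further down possible: typing.Callable cannot describe parameter names or keyword defaults, whereas a __call__-based Protocol can, and any function with a matching signature satisfies it structurally. A minimal sketch of that mechanism, with hypothetical names not taken from this repo:

from typing import Protocol


class Scaler(Protocol):
    # Callable[[float, float], float] could not express the keyword
    # default; a protocol with an explicit __call__ signature can.
    def __call__(self, x: float, factor: float = 2.0) -> float: ...


def double(x: float, factor: float = 2.0) -> float:
    return x * factor


f: Scaler = double  # accepted by a type checker: signatures match structurally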
@@ -37,7 +37,7 @@ class TrainingCallbacks():
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[int], None] = const()
+    on_before_optimize: Callable[[float, int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
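
With this hunk, on_before_optimize receives the current learning rate as its first argument, ahead of the epoch index, matching the call-site change in the last hunk below. A sketch of a conforming callback, assuming TrainingCallbacks is instantiated with keyword arguments as its dataclass fields suggest (the logging body is hypothetical, not from this repo):

def log_before_optimize(lr: float, epoch: int) -> None:
    # Matches the widened Callable[[float, int], None] signature:
    # learning rate first, then the epoch index.
    print(f"epoch {epoch}: stepping optimizer at lr={lr:.3e}")


callbacks = TrainingCallbacks(on_before_optimize=log_before_optimize)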
@@ -331,13 +331,17 @@ def loss_step(
     return loss, acc, bsz
 
 
+class LossCallable(Protocol):
+    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+
+
 def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
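
Because LossCallable is a Protocol, no inheritance or registration is needed: any plain function whose signature lines up, including the eval keyword default, can be handed to train_loop as loss_step, replacing the hard-to-read Union of two Callable shapes. A hypothetical stand-in to illustrate the contract (not the repo's real loss function):

import torch
from typing import Any, Tuple


def dummy_loss_step(step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]:
    # Returns a differentiable zero loss, a dummy accuracy of 1.0, and
    # the batch size, in the (loss, acc, bsz) order loss_step uses above.
    bsz = len(next(iter(batch.values())))
    return torch.zeros((), requires_grad=True), torch.ones(()), bsz


# Satisfies LossCallable structurally:
# train_loop(..., loss_step=dummy_loss_step)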
@@ -406,7 +410,7 @@ def train_loop(
 
             accelerator.backward(loss)
 
-            on_before_optimize(epoch)
+            on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
             optimizer.step()
             lr_scheduler.step()
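
get_last_lr() returns one learning rate per parameter group, as computed by the most recent lr_scheduler.step(), so get_last_lr()[0] is the first group's rate about to be applied by optimizer.step(). A small self-contained illustration (optimizer and scheduler chosen arbitrarily, not from this repo):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

opt.step()
sched.step()
print(sched.get_last_lr())     # [0.05]: a list, one entry per parameter group
print(sched.get_last_lr()[0])  # 0.05: the scalar passed to the callback above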