diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 | 
| commit | 96638bbd54ca7f91d44c938fae7275d3ecaa6add (patch) | |
| tree | b281a0e58820151e8738dfc5294bde5be482956b /training/functional.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.gz textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.bz2 textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.zip  | |
Fixed TI normalization order
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 8 | 
1 files changed, 4 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index e7c4320..b830261 100644 --- a/training/functional.py +++ b/training/functional.py  | |||
| @@ -38,8 +38,8 @@ class TrainingCallbacks(): | |||
| 38 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | 38 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | 
| 39 | on_log: Callable[[], dict[str, Any]] = const({}) | 39 | on_log: Callable[[], dict[str, Any]] = const({}) | 
| 40 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 40 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 
| 41 | on_before_optimize: Callable[[float, int], None] = const() | 41 | on_before_optimize: Callable[[float, int], Any] = const() | 
| 42 | on_after_optimize: Callable[[float], None] = const() | 42 | on_after_optimize: Callable[[Any, float], None] = const() | 
| 43 | on_after_epoch: Callable[[float], None] = const() | 43 | on_after_epoch: Callable[[float], None] = const() | 
| 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 
| 45 | on_sample: Callable[[int], None] = const() | 45 | on_sample: Callable[[int], None] = const() | 
| @@ -455,13 +455,13 @@ def train_loop( | |||
| 455 | local_progress_bar.set_postfix(**logs) | 455 | local_progress_bar.set_postfix(**logs) | 
| 456 | 456 | ||
| 457 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 457 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 
| 458 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | 458 | before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | 
| 459 | 459 | ||
| 460 | optimizer.step() | 460 | optimizer.step() | 
| 461 | lr_scheduler.step() | 461 | lr_scheduler.step() | 
| 462 | optimizer.zero_grad(set_to_none=True) | 462 | optimizer.zero_grad(set_to_none=True) | 
| 463 | 463 | ||
| 464 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | 464 | on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0]) | 
| 465 | 465 | ||
| 466 | local_progress_bar.update(1) | 466 | local_progress_bar.update(1) | 
| 467 | global_progress_bar.update(1) | 467 | global_progress_bar.update(1) | 
