path: root/training/functional.py
author     Volpeon <git@volpeon.ink>  2023-02-21 14:08:49 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-21 14:08:49 +0100
commit     96638bbd54ca7f91d44c938fae7275d3ecaa6add (patch)
tree       b281a0e58820151e8738dfc5294bde5be482956b /training/functional.py
parent     Fix (diff)
Fixed TI normalization order
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py | 8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index e7c4320..b830261 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -38,8 +38,8 @@ class TrainingCallbacks():
     on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], None] = const()
-    on_after_optimize: Callable[[float], None] = const()
+    on_before_optimize: Callable[[float, int], Any] = const()
+    on_after_optimize: Callable[[Any, float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -455,13 +455,13 @@ def train_loop(
             local_progress_bar.set_postfix(**logs)

             if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
+                before_optimize_result = on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)

                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

-                on_after_optimize(lr_scheduler.get_last_lr()[0])
+                on_after_optimize(before_optimize_result, lr_scheduler.get_last_lr()[0])

                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
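
The commit changes the callback contract so that on_before_optimize may return a value, which train_loop then passes as the first argument to on_after_optimize. This lets the textual-inversion (TI) normalization be split around the optimizer step: capture state before optimizer.step() and apply the normalization afterwards, fixing the ordering the commit message refers to. Below is a minimal sketch of callbacks matching the new signatures; the embedding/token_ids names and the norm-restoring logic are illustrative assumptions, not the repository's actual TI trainer code.

    from typing import Any

    import torch

    # Stand-ins for the trained textual-inversion embeddings; the real
    # trainer wires these up elsewhere.
    embedding = torch.nn.Embedding(10, 4)
    token_ids = torch.tensor([3, 7])


    def on_before_optimize(lr: float, epoch: int) -> Any:
        # Capture the pre-step norms of the trained embedding vectors and
        # return them; train_loop keeps this value for on_after_optimize.
        with torch.no_grad():
            return embedding.weight[token_ids].norm(dim=-1, keepdim=True).clone()


    def on_after_optimize(before_optimize_result: Any, lr: float) -> None:
        # Restore the captured norms on the updated vectors, i.e. the
        # normalization now happens after the optimizer step.
        pre_norms = before_optimize_result
        with torch.no_grad():
            rows = embedding.weight[token_ids]
            embedding.weight[token_ids] = rows * (
                pre_norms / rows.norm(dim=-1, keepdim=True).clamp(min=1e-12)
            )


    # Order of calls in train_loop after this commit (see the second hunk above):
    result = on_before_optimize(1e-4, epoch=0)
    # ... optimizer.step(); lr_scheduler.step(); optimizer.zero_grad(set_to_none=True) ...
    on_after_optimize(result, 1e-4)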