diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py index a5b339d..ee73ab2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -34,7 +34,6 @@ def const(result=None): | |||
34 | 34 | ||
35 | @dataclass | 35 | @dataclass |
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_prepare: Callable[[], None] = const() | ||
38 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) |
39 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
40 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
@@ -620,10 +619,8 @@ def train( | |||
620 | kwargs.update(extra) | 619 | kwargs.update(extra) |
621 | 620 | ||
622 | vae.to(accelerator.device, dtype=dtype) | 621 | vae.to(accelerator.device, dtype=dtype) |
623 | 622 | vae.requires_grad_(False) | |
624 | for model in (unet, text_encoder, vae): | 623 | vae.eval() |
625 | model.requires_grad_(False) | ||
626 | model.eval() | ||
627 | 624 | ||
628 | callbacks = strategy.callbacks( | 625 | callbacks = strategy.callbacks( |
629 | accelerator=accelerator, | 626 | accelerator=accelerator, |
@@ -636,8 +633,6 @@ def train( | |||
636 | **kwargs, | 633 | **kwargs, |
637 | ) | 634 | ) |
638 | 635 | ||
639 | callbacks.on_prepare() | ||
640 | |||
641 | loss_step_ = partial( | 636 | loss_step_ = partial( |
642 | loss_step, | 637 | loss_step, |
643 | vae, | 638 | vae, |