diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py index a5b339d..ee73ab2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -34,7 +34,6 @@ def const(result=None): | |||
| 34 | 34 | ||
| 35 | @dataclass | 35 | @dataclass |
| 36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
| 37 | on_prepare: Callable[[], None] = const() | ||
| 38 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) |
| 39 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 40 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| @@ -620,10 +619,8 @@ def train( | |||
| 620 | kwargs.update(extra) | 619 | kwargs.update(extra) |
| 621 | 620 | ||
| 622 | vae.to(accelerator.device, dtype=dtype) | 621 | vae.to(accelerator.device, dtype=dtype) |
| 623 | 622 | vae.requires_grad_(False) | |
| 624 | for model in (unet, text_encoder, vae): | 623 | vae.eval() |
| 625 | model.requires_grad_(False) | ||
| 626 | model.eval() | ||
| 627 | 624 | ||
| 628 | callbacks = strategy.callbacks( | 625 | callbacks = strategy.callbacks( |
| 629 | accelerator=accelerator, | 626 | accelerator=accelerator, |
| @@ -636,8 +633,6 @@ def train( | |||
| 636 | **kwargs, | 633 | **kwargs, |
| 637 | ) | 634 | ) |
| 638 | 635 | ||
| 639 | callbacks.on_prepare() | ||
| 640 | |||
| 641 | loss_step_ = partial( | 636 | loss_step_ = partial( |
| 642 | loss_step, | 637 | loss_step, |
| 643 | vae, | 638 | vae, |
