summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py9
1 files changed, 2 insertions, 7 deletions
diff --git a/training/functional.py b/training/functional.py
index a5b339d..ee73ab2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
34 34
35@dataclass 35@dataclass
36class TrainingCallbacks(): 36class TrainingCallbacks():
37 on_prepare: Callable[[], None] = const()
38 on_accum_model: Callable[[], torch.nn.Module] = const(None) 37 on_accum_model: Callable[[], torch.nn.Module] = const(None)
39 on_log: Callable[[], dict[str, Any]] = const({}) 38 on_log: Callable[[], dict[str, Any]] = const({})
40 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -620,10 +619,8 @@ def train(
620 kwargs.update(extra) 619 kwargs.update(extra)
621 620
622 vae.to(accelerator.device, dtype=dtype) 621 vae.to(accelerator.device, dtype=dtype)
623 622 vae.requires_grad_(False)
624 for model in (unet, text_encoder, vae): 623 vae.eval()
625 model.requires_grad_(False)
626 model.eval()
627 624
628 callbacks = strategy.callbacks( 625 callbacks = strategy.callbacks(
629 accelerator=accelerator, 626 accelerator=accelerator,
@@ -636,8 +633,6 @@ def train(
636 **kwargs, 633 **kwargs,
637 ) 634 )
638 635
639 callbacks.on_prepare()
640
641 loss_step_ = partial( 636 loss_step_ = partial(
642 loss_step, 637 loss_step,
643 vae, 638 vae,