From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 24 Mar 2023 10:53:16 +0100 Subject: Refactoring, fixed Lora training --- training/functional.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index a5b339d..ee73ab2 100644 --- a/training/functional.py +++ b/training/functional.py @@ -34,7 +34,6 @@ def const(result=None): @dataclass class TrainingCallbacks(): - on_prepare: Callable[[], None] = const() on_accum_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) @@ -620,10 +619,8 @@ def train( kwargs.update(extra) vae.to(accelerator.device, dtype=dtype) - - for model in (unet, text_encoder, vae): - model.requires_grad_(False) - model.eval() + vae.requires_grad_(False) + vae.eval() callbacks = strategy.callbacks( accelerator=accelerator, @@ -636,8 +633,6 @@ def train( **kwargs, ) - callbacks.on_prepare() - loss_step_ = partial( loss_step, vae, -- cgit v1.2.3-54-g00ecf