summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index e14aeea..46d25f6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -644,11 +644,9 @@ def train(
644 min_snr_gamma: int = 5, 644 min_snr_gamma: int = 5,
645 **kwargs, 645 **kwargs,
646): 646):
647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) 648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
649 649
650 kwargs.update(extra)
651
652 vae.to(accelerator.device, dtype=dtype) 650 vae.to(accelerator.device, dtype=dtype)
653 vae.requires_grad_(False) 651 vae.requires_grad_(False)
654 vae.eval() 652 vae.eval()