summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index b6b5d87..1548784 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -515,6 +515,7 @@ def train(
515 optimizer: torch.optim.Optimizer, 515 optimizer: torch.optim.Optimizer,
516 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 516 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
517 callbacks_fn: Callable[..., TrainingCallbacks], 517 callbacks_fn: Callable[..., TrainingCallbacks],
518 prepare_unet: bool = False,
518 num_train_epochs: int = 100, 519 num_train_epochs: int = 100,
519 sample_frequency: int = 20, 520 sample_frequency: int = 20,
520 checkpoint_frequency: int = 50, 521 checkpoint_frequency: int = 50,
@@ -523,9 +524,19 @@ def train(
523 prior_loss_weight: float = 1.0, 524 prior_loss_weight: float = 1.0,
524 **kwargs, 525 **kwargs,
525): 526):
526 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 527 prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
527 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler 528
528 ) 529 if prepare_unet:
530 prep.append(unet)
531
532 prep = accelerator.prepare(*prep)
533
534 if prepare_unet:
535 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
536 else:
537 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
538
539 unet.to(accelerator.device, dtype=dtype)
529 540
530 vae.to(accelerator.device, dtype=dtype) 541 vae.to(accelerator.device, dtype=dtype)
531 542