diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index b6b5d87..1548784 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -515,6 +515,7 @@ def train( | |||
515 | optimizer: torch.optim.Optimizer, | 515 | optimizer: torch.optim.Optimizer, |
516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
517 | callbacks_fn: Callable[..., TrainingCallbacks], | 517 | callbacks_fn: Callable[..., TrainingCallbacks], |
518 | prepare_unet: bool = False, | ||
518 | num_train_epochs: int = 100, | 519 | num_train_epochs: int = 100, |
519 | sample_frequency: int = 20, | 520 | sample_frequency: int = 20, |
520 | checkpoint_frequency: int = 50, | 521 | checkpoint_frequency: int = 50, |
@@ -523,9 +524,19 @@ def train( | |||
523 | prior_loss_weight: float = 1.0, | 524 | prior_loss_weight: float = 1.0, |
524 | **kwargs, | 525 | **kwargs, |
525 | ): | 526 | ): |
526 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 527 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
527 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 528 | |
528 | ) | 529 | if prepare_unet: |
530 | prep.append(unet) | ||
531 | |||
532 | prep = accelerator.prepare(*prep) | ||
533 | |||
534 | if prepare_unet: | ||
535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
536 | else: | ||
537 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
538 | |||
539 | unet.to(accelerator.device, dtype=dtype) | ||
529 | 540 | ||
530 | vae.to(accelerator.device, dtype=dtype) | 541 | vae.to(accelerator.device, dtype=dtype) |
531 | 542 | ||