From 36440e48ce279872d6e736bcb1bf57d13da73a11 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 17:09:01 +0100 Subject: Moved multi-TI code from Dreambooth to TI script --- training/functional.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index b6b5d87..1548784 100644 --- a/training/functional.py +++ b/training/functional.py @@ -515,6 +515,7 @@ def train( optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, callbacks_fn: Callable[..., TrainingCallbacks], + prepare_unet: bool = False, num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, @@ -523,9 +524,19 @@ def train( prior_loss_weight: float = 1.0, **kwargs, ): - unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler - ) + prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] + + if prepare_unet: + prep.append(unet) + + prep = accelerator.prepare(*prep) + + if prepare_unet: + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep + else: + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep + + unet.to(accelerator.device, dtype=dtype) vae.to(accelerator.device, dtype=dtype) -- cgit v1.2.3-54-g00ecf