| author | Volpeon <git@volpeon.ink> | 2023-01-16 17:09:01 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-16 17:09:01 +0100 |
| commit | 36440e48ce279872d6e736bcb1bf57d13da73a11 (patch) | |
| tree | 8ba9593d8a887517c70b01932c137c9c3f759e8f /training | |
| parent | More training adjustments (diff) | |
| download | textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.gz textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.bz2 textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.zip | |
Moved multi-TI code from Dreambooth to TI script
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 17 |
1 file changed, 14 insertions, 3 deletions
```diff
diff --git a/training/functional.py b/training/functional.py
index b6b5d87..1548784 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -515,6 +515,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
+    prepare_unet: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -523,9 +524,19 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
+    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
+
+    if prepare_unet:
+        prep.append(unet)
+
+    prep = accelerator.prepare(*prep)
+
+    if prepare_unet:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
+    else:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
+
+    unet.to(accelerator.device, dtype=dtype)
 
     vae.to(accelerator.device, dtype=dtype)
 
```
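For context, the core of the change is the conditional `accelerator.prepare()` call: the new `prepare_unet` flag decides whether the UNet is wrapped by Accelerate (when it takes part in training, as in Dreambooth) or merely moved to the device (as in textual inversion, where it stays frozen). The sketch below is a minimal standalone illustration of the same pattern, not the repository's actual `train()` function; the helper name `prepare_models` and its exact parameter list are hypothetical.

```python
# Minimal sketch of the conditional-prepare pattern from this commit.
# prepare_models() and its signature are illustrative, not the repo's real API.
import torch
from accelerate import Accelerator


def prepare_models(accelerator: Accelerator, unet, text_encoder, optimizer,
                   train_dataloader, val_dataloader, lr_scheduler,
                   prepare_unet: bool = False, dtype=torch.float32):
    # Always let Accelerate wrap the objects involved in the backward pass.
    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]

    # Only wrap the UNet when it is actually trained (Dreambooth-style);
    # for textual inversion it stays frozen and is just moved to the device.
    if prepare_unet:
        prep.append(unet)

    prep = accelerator.prepare(*prep)

    # accelerator.prepare() returns objects in the order they were passed,
    # so the unpacking must mirror the branch above.
    if prepare_unet:
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
    else:
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep

    unet.to(accelerator.device, dtype=dtype)
    return unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
```

With `prepare_unet` defaulting to `False`, the textual-inversion path presumably skips DDP wrapping and gradient synchronization for the UNet entirely, while a Dreambooth-style caller would pass `prepare_unet=True` to train it as before.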
