diff options
author | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
commit | 2469501c3951a9ed86c820cddf7b32144a4a1c8d (patch) | |
tree | 9820efaa12fd31670616c1fd9da3e6bb06580aaf /training/functional.py | |
parent | Update (diff) | |
download | textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.gz textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.bz2 textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.zip |
Move Accelerator preparation into strategy
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/training/functional.py b/training/functional.py index a450ef6..fb135c4 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -7,6 +7,7 @@ from pathlib import Path | |||
7 | import itertools | 7 | import itertools |
8 | 8 | ||
9 | import torch | 9 | import torch |
10 | import torch.nn as nn | ||
10 | import torch.nn.functional as F | 11 | import torch.nn.functional as F |
11 | from torch.utils.data import DataLoader | 12 | from torch.utils.data import DataLoader |
12 | 13 | ||
@@ -45,10 +46,20 @@ class TrainingCallbacks(): | |||
45 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
46 | 47 | ||
47 | 48 | ||
49 | class TrainingStrategyPrepareCallable(Protocol): | ||
50 | def __call__( | ||
51 | self, | ||
52 | accelerator: Accelerator, | ||
53 | text_encoder: CLIPTextModel, | ||
54 | unet: UNet2DConditionModel, | ||
55 | *args | ||
56 | ) -> Tuple: ... | ||
57 | |||
58 | |||
48 | @dataclass | 59 | @dataclass |
49 | class TrainingStrategy(): | 60 | class TrainingStrategy(): |
50 | callbacks: Callable[..., TrainingCallbacks] | 61 | callbacks: Callable[..., TrainingCallbacks] |
51 | prepare_unet: bool = False | 62 | prepare: TrainingStrategyPrepareCallable |
52 | 63 | ||
53 | 64 | ||
54 | def make_grid(images, rows, cols): | 65 | def make_grid(images, rows, cols): |
@@ -535,19 +546,8 @@ def train( | |||
535 | prior_loss_weight: float = 1.0, | 546 | prior_loss_weight: float = 1.0, |
536 | **kwargs, | 547 | **kwargs, |
537 | ): | 548 | ): |
538 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 549 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
539 | 550 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | |
540 | if strategy.prepare_unet: | ||
541 | prep.append(unet) | ||
542 | |||
543 | prep = accelerator.prepare(*prep) | ||
544 | |||
545 | if strategy.prepare_unet: | ||
546 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
547 | else: | ||
548 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
549 | |||
550 | unet.to(accelerator.device, dtype=dtype) | ||
551 | 551 | ||
552 | vae.to(accelerator.device, dtype=dtype) | 552 | vae.to(accelerator.device, dtype=dtype) |
553 | 553 | ||