diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 28 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 14 | ||||
-rw-r--r-- | training/strategy/ti.py | 22 |
3 files changed, 48 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py index a450ef6..fb135c4 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -7,6 +7,7 @@ from pathlib import Path | |||
7 | import itertools | 7 | import itertools |
8 | 8 | ||
9 | import torch | 9 | import torch |
10 | import torch.nn as nn | ||
10 | import torch.nn.functional as F | 11 | import torch.nn.functional as F |
11 | from torch.utils.data import DataLoader | 12 | from torch.utils.data import DataLoader |
12 | 13 | ||
@@ -45,10 +46,20 @@ class TrainingCallbacks(): | |||
45 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
46 | 47 | ||
47 | 48 | ||
49 | class TrainingStrategyPrepareCallable(Protocol): | ||
50 | def __call__( | ||
51 | self, | ||
52 | accelerator: Accelerator, | ||
53 | text_encoder: CLIPTextModel, | ||
54 | unet: UNet2DConditionModel, | ||
55 | *args | ||
56 | ) -> Tuple: ... | ||
57 | |||
58 | |||
48 | @dataclass | 59 | @dataclass |
49 | class TrainingStrategy(): | 60 | class TrainingStrategy(): |
50 | callbacks: Callable[..., TrainingCallbacks] | 61 | callbacks: Callable[..., TrainingCallbacks] |
51 | prepare_unet: bool = False | 62 | prepare: TrainingStrategyPrepareCallable |
52 | 63 | ||
53 | 64 | ||
54 | def make_grid(images, rows, cols): | 65 | def make_grid(images, rows, cols): |
@@ -535,19 +546,8 @@ def train( | |||
535 | prior_loss_weight: float = 1.0, | 546 | prior_loss_weight: float = 1.0, |
536 | **kwargs, | 547 | **kwargs, |
537 | ): | 548 | ): |
538 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 549 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
539 | 550 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | |
540 | if strategy.prepare_unet: | ||
541 | prep.append(unet) | ||
542 | |||
543 | prep = accelerator.prepare(*prep) | ||
544 | |||
545 | if strategy.prepare_unet: | ||
546 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
547 | else: | ||
548 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
549 | |||
550 | unet.to(accelerator.device, dtype=dtype) | ||
551 | 551 | ||
552 | vae.to(accelerator.device, dtype=dtype) | 552 | vae.to(accelerator.device, dtype=dtype) |
553 | 553 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index f57e736..1277939 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -6,6 +6,7 @@ from pathlib import Path | |||
6 | import itertools | 6 | import itertools |
7 | 7 | ||
8 | import torch | 8 | import torch |
9 | import torch.nn as nn | ||
9 | from torch.utils.data import DataLoader | 10 | from torch.utils.data import DataLoader |
10 | 11 | ||
11 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
@@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks( | |||
186 | ) | 187 | ) |
187 | 188 | ||
188 | 189 | ||
190 | def dreambooth_prepare( | ||
191 | accelerator: Accelerator, | ||
192 | text_encoder: CLIPTextModel, | ||
193 | unet: UNet2DConditionModel, | ||
194 | *args | ||
195 | ): | ||
196 | prep = [text_encoder, unet] + list(args) | ||
197 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
198 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
199 | |||
200 | |||
189 | dreambooth_strategy = TrainingStrategy( | 201 | dreambooth_strategy = TrainingStrategy( |
190 | callbacks=dreambooth_strategy_callbacks, | 202 | callbacks=dreambooth_strategy_callbacks, |
191 | prepare_unet=True | 203 | prepare=dreambooth_prepare |
192 | ) | 204 | ) |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index e922954..6a76f98 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext | |||
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | 6 | ||
7 | import torch | 7 | import torch |
8 | import torch.nn as nn | ||
8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
9 | 10 | ||
10 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
@@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks( | |||
94 | return nullcontext() | 95 | return nullcontext() |
95 | 96 | ||
96 | def on_model(): | 97 | def on_model(): |
97 | return text_encoder | 98 | return text_encoder.text_model.embeddings.temp_token_embedding |
98 | 99 | ||
99 | def on_prepare(): | 100 | def on_prepare(): |
100 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 101 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) |
@@ -163,6 +164,25 @@ def textual_inversion_strategy_callbacks( | |||
163 | ) | 164 | ) |
164 | 165 | ||
165 | 166 | ||
167 | def textual_inversion_prepare( | ||
168 | accelerator: Accelerator, | ||
169 | text_encoder: CLIPTextModel, | ||
170 | unet: UNet2DConditionModel, | ||
171 | *args | ||
172 | ): | ||
173 | weight_dtype = torch.float32 | ||
174 | if accelerator.state.mixed_precision == "fp16": | ||
175 | weight_dtype = torch.float16 | ||
176 | elif accelerator.state.mixed_precision == "bf16": | ||
177 | weight_dtype = torch.bfloat16 | ||
178 | |||
179 | prep = [text_encoder] + list(args) | ||
180 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
181 | unet.to(accelerator.device, dtype=weight_dtype) | ||
182 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
183 | |||
184 | |||
166 | textual_inversion_strategy = TrainingStrategy( | 185 | textual_inversion_strategy = TrainingStrategy( |
167 | callbacks=textual_inversion_strategy_callbacks, | 186 | callbacks=textual_inversion_strategy_callbacks, |
187 | prepare=textual_inversion_prepare, | ||
168 | ) | 188 | ) |