diff options
| -rw-r--r-- | train_ti.py | 6 | ||||
| -rw-r--r-- | training/functional.py | 28 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 14 | ||||
| -rw-r--r-- | training/strategy/ti.py | 22 |
4 files changed, 51 insertions, 19 deletions
diff --git a/train_ti.py b/train_ti.py index 7aa4960..451b61b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -159,7 +159,7 @@ def parse_args(): | |||
| 159 | parser.add_argument( | 159 | parser.add_argument( |
| 160 | "--tag_dropout", | 160 | "--tag_dropout", |
| 161 | type=float, | 161 | type=float, |
| 162 | default=0.1, | 162 | default=0, |
| 163 | help="Tag dropout probability.", | 163 | help="Tag dropout probability.", |
| 164 | ) | 164 | ) |
| 165 | parser.add_argument( | 165 | parser.add_argument( |
| @@ -407,7 +407,7 @@ def parse_args(): | |||
| 407 | ) | 407 | ) |
| 408 | parser.add_argument( | 408 | parser.add_argument( |
| 409 | "--emb_decay", | 409 | "--emb_decay", |
| 410 | default=1e-2, | 410 | default=10, |
| 411 | type=float, | 411 | type=float, |
| 412 | help="Embedding decay factor." | 412 | help="Embedding decay factor." |
| 413 | ) | 413 | ) |
| @@ -597,7 +597,7 @@ def main(): | |||
| 597 | 597 | ||
| 598 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 598 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): |
| 599 | if len(placeholder_tokens) == 1: | 599 | if len(placeholder_tokens) == 1: |
| 600 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}") | 600 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
| 601 | else: | 601 | else: |
| 602 | sample_output_dir = output_dir.joinpath("samples") | 602 | sample_output_dir = output_dir.joinpath("samples") |
| 603 | 603 | ||
diff --git a/training/functional.py b/training/functional.py index a450ef6..fb135c4 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -7,6 +7,7 @@ from pathlib import Path | |||
| 7 | import itertools | 7 | import itertools |
| 8 | 8 | ||
| 9 | import torch | 9 | import torch |
| 10 | import torch.nn as nn | ||
| 10 | import torch.nn.functional as F | 11 | import torch.nn.functional as F |
| 11 | from torch.utils.data import DataLoader | 12 | from torch.utils.data import DataLoader |
| 12 | 13 | ||
| @@ -45,10 +46,20 @@ class TrainingCallbacks(): | |||
| 45 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
| 46 | 47 | ||
| 47 | 48 | ||
| 49 | class TrainingStrategyPrepareCallable(Protocol): | ||
| 50 | def __call__( | ||
| 51 | self, | ||
| 52 | accelerator: Accelerator, | ||
| 53 | text_encoder: CLIPTextModel, | ||
| 54 | unet: UNet2DConditionModel, | ||
| 55 | *args | ||
| 56 | ) -> Tuple: ... | ||
| 57 | |||
| 58 | |||
| 48 | @dataclass | 59 | @dataclass |
| 49 | class TrainingStrategy(): | 60 | class TrainingStrategy(): |
| 50 | callbacks: Callable[..., TrainingCallbacks] | 61 | callbacks: Callable[..., TrainingCallbacks] |
| 51 | prepare_unet: bool = False | 62 | prepare: TrainingStrategyPrepareCallable |
| 52 | 63 | ||
| 53 | 64 | ||
| 54 | def make_grid(images, rows, cols): | 65 | def make_grid(images, rows, cols): |
| @@ -535,19 +546,8 @@ def train( | |||
| 535 | prior_loss_weight: float = 1.0, | 546 | prior_loss_weight: float = 1.0, |
| 536 | **kwargs, | 547 | **kwargs, |
| 537 | ): | 548 | ): |
| 538 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 549 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
| 539 | 550 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | |
| 540 | if strategy.prepare_unet: | ||
| 541 | prep.append(unet) | ||
| 542 | |||
| 543 | prep = accelerator.prepare(*prep) | ||
| 544 | |||
| 545 | if strategy.prepare_unet: | ||
| 546 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
| 547 | else: | ||
| 548 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
| 549 | |||
| 550 | unet.to(accelerator.device, dtype=dtype) | ||
| 551 | 551 | ||
| 552 | vae.to(accelerator.device, dtype=dtype) | 552 | vae.to(accelerator.device, dtype=dtype) |
| 553 | 553 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index f57e736..1277939 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -6,6 +6,7 @@ from pathlib import Path | |||
| 6 | import itertools | 6 | import itertools |
| 7 | 7 | ||
| 8 | import torch | 8 | import torch |
| 9 | import torch.nn as nn | ||
| 9 | from torch.utils.data import DataLoader | 10 | from torch.utils.data import DataLoader |
| 10 | 11 | ||
| 11 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
| @@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks( | |||
| 186 | ) | 187 | ) |
| 187 | 188 | ||
| 188 | 189 | ||
| 190 | def dreambooth_prepare( | ||
| 191 | accelerator: Accelerator, | ||
| 192 | text_encoder: CLIPTextModel, | ||
| 193 | unet: UNet2DConditionModel, | ||
| 194 | *args | ||
| 195 | ): | ||
| 196 | prep = [text_encoder, unet] + list(args) | ||
| 197 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
| 198 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 199 | |||
| 200 | |||
| 189 | dreambooth_strategy = TrainingStrategy( | 201 | dreambooth_strategy = TrainingStrategy( |
| 190 | callbacks=dreambooth_strategy_callbacks, | 202 | callbacks=dreambooth_strategy_callbacks, |
| 191 | prepare_unet=True | 203 | prepare=dreambooth_prepare |
| 192 | ) | 204 | ) |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index e922954..6a76f98 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext | |||
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | 6 | ||
| 7 | import torch | 7 | import torch |
| 8 | import torch.nn as nn | ||
| 8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
| 9 | 10 | ||
| 10 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
| @@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks( | |||
| 94 | return nullcontext() | 95 | return nullcontext() |
| 95 | 96 | ||
| 96 | def on_model(): | 97 | def on_model(): |
| 97 | return text_encoder | 98 | return text_encoder.text_model.embeddings.temp_token_embedding |
| 98 | 99 | ||
| 99 | def on_prepare(): | 100 | def on_prepare(): |
| 100 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 101 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) |
| @@ -163,6 +164,25 @@ def textual_inversion_strategy_callbacks( | |||
| 163 | ) | 164 | ) |
| 164 | 165 | ||
| 165 | 166 | ||
| 167 | def textual_inversion_prepare( | ||
| 168 | accelerator: Accelerator, | ||
| 169 | text_encoder: CLIPTextModel, | ||
| 170 | unet: UNet2DConditionModel, | ||
| 171 | *args | ||
| 172 | ): | ||
| 173 | weight_dtype = torch.float32 | ||
| 174 | if accelerator.state.mixed_precision == "fp16": | ||
| 175 | weight_dtype = torch.float16 | ||
| 176 | elif accelerator.state.mixed_precision == "bf16": | ||
| 177 | weight_dtype = torch.bfloat16 | ||
| 178 | |||
| 179 | prep = [text_encoder] + list(args) | ||
| 180 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
| 181 | unet.to(accelerator.device, dtype=weight_dtype) | ||
| 182 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 183 | |||
| 184 | |||
| 166 | textual_inversion_strategy = TrainingStrategy( | 185 | textual_inversion_strategy = TrainingStrategy( |
| 167 | callbacks=textual_inversion_strategy_callbacks, | 186 | callbacks=textual_inversion_strategy_callbacks, |
| 187 | prepare=textual_inversion_prepare, | ||
| 168 | ) | 188 | ) |
