From 2469501c3951a9ed86c820cddf7b32144a4a1c8d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 19 Jan 2023 09:04:39 +0100
Subject: Move Accelerator preparation into strategy

---
 train_ti.py                     |  6 +++---
 training/functional.py          | 28 ++++++++++++++--------------
 training/strategy/dreambooth.py | 14 +++++++++++++-
 training/strategy/ti.py         | 22 +++++++++++++++++++++-
 4 files changed, 51 insertions(+), 19 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 7aa4960..451b61b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -159,7 +159,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0.1,
+        default=0,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -407,7 +407,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=10,
         type=float,
         help="Embedding decay factor."
     )
@@ -597,7 +597,7 @@ def main():
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
         if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}")
+            sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
         else:
             sample_output_dir = output_dir.joinpath("samples")
 
diff --git a/training/functional.py b/training/functional.py
index a450ef6..fb135c4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,6 +7,7 @@ from pathlib import Path
 import itertools
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
@@ -45,10 +46,20 @@ class TrainingCallbacks():
     on_checkpoint: Callable[[int, str], None] = const()
 
 
+class TrainingStrategyPrepareCallable(Protocol):
+    def __call__(
+        self,
+        accelerator: Accelerator,
+        text_encoder: CLIPTextModel,
+        unet: UNet2DConditionModel,
+        *args
+    ) -> Tuple: ...
+
+
 @dataclass
 class TrainingStrategy():
     callbacks: Callable[..., TrainingCallbacks]
-    prepare_unet: bool = False
+    prepare: TrainingStrategyPrepareCallable
 
 
 def make_grid(images, rows, cols):
@@ -535,19 +546,8 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
-
-    if strategy.prepare_unet:
-        prep.append(unet)
-
-    prep = accelerator.prepare(*prep)
-
-    if strategy.prepare_unet:
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
-    else:
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
-
-    unet.to(accelerator.device, dtype=dtype)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
     vae.to(accelerator.device, dtype=dtype)
 
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index f57e736..1277939 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -6,6 +6,7 @@ from pathlib import Path
 import itertools
 
 import torch
+import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -186,7 +187,18 @@
     )
 
 
+def dreambooth_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    *args
+):
+    prep = [text_encoder, unet] + list(args)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+
+
 dreambooth_strategy = TrainingStrategy(
     callbacks=dreambooth_strategy_callbacks,
-    prepare_unet=True
+    prepare=dreambooth_prepare
 )
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index e922954..6a76f98 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
+import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks(
             return nullcontext()
 
     def on_model():
-        return text_encoder
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
@@ -163,6 +164,25 @@
     )
 
 
+def textual_inversion_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    *args
+):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    prep = [text_encoder] + list(args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+
+
 textual_inversion_strategy = TrainingStrategy(
     callbacks=textual_inversion_strategy_callbacks,
+    prepare=textual_inversion_prepare,
 )
-- 
cgit v1.2.3-54-g00ecf