From 2469501c3951a9ed86c820cddf7b32144a4a1c8d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 19 Jan 2023 09:04:39 +0100
Subject: Move Accelerator preparation into strategy

---
 training/strategy/dreambooth.py | 14 +++++++++++++-
 training/strategy/ti.py         | 22 +++++++++++++++++++++-
 2 files changed, 34 insertions(+), 2 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index f57e736..1277939 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -6,6 +6,7 @@ from pathlib import Path
 import itertools
 
 import torch
+import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks(
     )
 
 
+def dreambooth_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    *args
+):
+    prep = [text_encoder, unet] + list(args)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+
+
 dreambooth_strategy = TrainingStrategy(
     callbacks=dreambooth_strategy_callbacks,
-    prepare_unet=True
+    prepare=dreambooth_prepare
 )

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index e922954..6a76f98 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
+import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks(
             return nullcontext()
 
     def on_model():
-        return text_encoder
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
@@ -163,6 +164,25 @@ def textual_inversion_strategy_callbacks(
     )
 
 
+def textual_inversion_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    *args
+):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    prep = [text_encoder] + list(args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+
+
 textual_inversion_strategy = TrainingStrategy(
     callbacks=textual_inversion_strategy_callbacks,
+    prepare=textual_inversion_prepare,
 )
-- 
cgit v1.2.3-70-g09d2
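
For context (not part of the patch): a minimal sketch of how the new `prepare` hook might be consumed on the trainer side. The `TrainingStrategy` definition and the `train_loop_setup` helper below are assumptions for illustration; only the strategy-side changes above are in this commit.

```python
from dataclasses import dataclass
from typing import Callable

from accelerate import Accelerator


# Assumed shape of TrainingStrategy after this patch: the old
# `prepare_unet` flag is gone, and each strategy supplies its own
# `prepare` callable (dreambooth_prepare / textual_inversion_prepare).
@dataclass
class TrainingStrategy:
    callbacks: Callable
    prepare: Callable


def train_loop_setup(strategy, accelerator: Accelerator, text_encoder, unet,
                     optimizer, train_dataloader, val_dataloader, lr_scheduler):
    # The strategy decides which modules go through accelerator.prepare()
    # (both models for DreamBooth, only the text encoder for textual
    # inversion, where the frozen UNet is just moved to the device in the
    # right dtype) and always returns the same six objects.
    return strategy.prepare(
        accelerator, text_encoder, unet,
        optimizer, train_dataloader, val_dataloader, lr_scheduler,
    )
```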