From 2469501c3951a9ed86c820cddf7b32144a4a1c8d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 19 Jan 2023 09:04:39 +0100 Subject: Move Accelerator preparation into strategy --- training/strategy/dreambooth.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index f57e736..1277939 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -6,6 +6,7 @@ from pathlib import Path import itertools import torch +import torch.nn as nn from torch.utils.data import DataLoader from accelerate import Accelerator @@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks( ) +def dreambooth_prepare( + accelerator: Accelerator, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + *args +): + prep = [text_encoder, unet] + list(args) + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + + dreambooth_strategy = TrainingStrategy( callbacks=dreambooth_strategy_callbacks, - prepare_unet=True + prepare=dreambooth_prepare ) -- cgit v1.2.3-54-g00ecf