From 2469501c3951a9ed86c820cddf7b32144a4a1c8d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 19 Jan 2023 09:04:39 +0100
Subject: Move Accelerator preparation into strategy

---
 training/functional.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

(limited to 'training/functional.py')

diff --git a/training/functional.py b/training/functional.py
index a450ef6..fb135c4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,6 +7,7 @@ from pathlib import Path
 import itertools
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
@@ -45,10 +46,20 @@ class TrainingCallbacks():
     on_checkpoint: Callable[[int, str], None] = const()
 
 
+class TrainingStrategyPrepareCallable(Protocol):
+    def __call__(
+        self,
+        accelerator: Accelerator,
+        text_encoder: CLIPTextModel,
+        unet: UNet2DConditionModel,
+        *args
+    ) -> Tuple: ...
+
+
 @dataclass
 class TrainingStrategy():
     callbacks: Callable[..., TrainingCallbacks]
-    prepare_unet: bool = False
+    prepare: TrainingStrategyPrepareCallable
 
 
 def make_grid(images, rows, cols):
@@ -535,19 +546,8 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
-
-    if strategy.prepare_unet:
-        prep.append(unet)
-
-    prep = accelerator.prepare(*prep)
-
-    if strategy.prepare_unet:
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
-    else:
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
-
-    unet.to(accelerator.device, dtype=dtype)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
     vae.to(accelerator.device, dtype=dtype)
 
-- 
cgit v1.2.3-54-g00ecf
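
For context, a minimal sketch of what concrete prepare implementations might look
like under the new interface. This is not part of the commit: the function names
textual_inversion_prepare and dreambooth_prepare are hypothetical and dtype
handling is omitted; only the call signature and the return order follow the
TrainingStrategyPrepareCallable protocol and the unpacking in train() above.

    from typing import Tuple

    from accelerate import Accelerator
    from diffusers import UNet2DConditionModel
    from transformers import CLIPTextModel


    def textual_inversion_prepare(
        accelerator: Accelerator,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        *args
    ) -> Tuple:
        # Hypothetical strategy: only the text encoder is trained, so only it
        # and the remaining objects (optimizer, dataloaders, lr_scheduler,
        # passed via *args) are wrapped by accelerator.prepare(); the unet is
        # merely moved to the right device.
        text_encoder, *rest = accelerator.prepare(text_encoder, *args)
        unet.to(accelerator.device)
        # Return order must match the unpacking in train(): text_encoder,
        # unet, optimizer, train_dataloader, val_dataloader, lr_scheduler.
        return (text_encoder, unet, *rest)


    def dreambooth_prepare(
        accelerator: Accelerator,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        *args
    ) -> Tuple:
        # Hypothetical strategy: both models are trained, so both go through
        # accelerator.prepare().
        text_encoder, unet, *rest = accelerator.prepare(text_encoder, unet, *args)
        return (text_encoder, unet, *rest)

Either function could then be assigned to a TrainingStrategy's prepare field,
replacing the old prepare_unet flag with preparation logic owned by the
strategy itself.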