summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/training/functional.py b/training/functional.py
index a450ef6..fb135c4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,6 +7,7 @@ from pathlib import Path
7import itertools 7import itertools
8 8
9import torch 9import torch
10import torch.nn as nn
10import torch.nn.functional as F 11import torch.nn.functional as F
11from torch.utils.data import DataLoader 12from torch.utils.data import DataLoader
12 13
@@ -45,10 +46,20 @@ class TrainingCallbacks():
45 on_checkpoint: Callable[[int, str], None] = const() 46 on_checkpoint: Callable[[int, str], None] = const()
46 47
47 48
49class TrainingStrategyPrepareCallable(Protocol):
50 def __call__(
51 self,
52 accelerator: Accelerator,
53 text_encoder: CLIPTextModel,
54 unet: UNet2DConditionModel,
55 *args
56 ) -> Tuple: ...
57
58
48@dataclass 59@dataclass
49class TrainingStrategy(): 60class TrainingStrategy():
50 callbacks: Callable[..., TrainingCallbacks] 61 callbacks: Callable[..., TrainingCallbacks]
51 prepare_unet: bool = False 62 prepare: TrainingStrategyPrepareCallable
52 63
53 64
54def make_grid(images, rows, cols): 65def make_grid(images, rows, cols):
@@ -535,19 +546,8 @@ def train(
535 prior_loss_weight: float = 1.0, 546 prior_loss_weight: float = 1.0,
536 **kwargs, 547 **kwargs,
537): 548):
538 prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] 549 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
539 550 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
540 if strategy.prepare_unet:
541 prep.append(unet)
542
543 prep = accelerator.prepare(*prep)
544
545 if strategy.prepare_unet:
546 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
547 else:
548 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
549
550 unet.to(accelerator.device, dtype=dtype)
551 551
552 vae.to(accelerator.device, dtype=dtype) 552 vae.to(accelerator.device, dtype=dtype)
553 553