summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-19 09:04:39 +0100
committerVolpeon <git@volpeon.ink>2023-01-19 09:04:39 +0100
commit2469501c3951a9ed86c820cddf7b32144a4a1c8d (patch)
tree9820efaa12fd31670616c1fd9da3e6bb06580aaf /training
parentUpdate (diff)
downloadtextual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.gz
textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.bz2
textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.zip
Move Accelerator preparation into strategy
Diffstat (limited to 'training')
-rw-r--r--training/functional.py28
-rw-r--r--training/strategy/dreambooth.py14
-rw-r--r--training/strategy/ti.py22
3 files changed, 48 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py
index a450ef6..fb135c4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,6 +7,7 @@ from pathlib import Path
7import itertools 7import itertools
8 8
9import torch 9import torch
10import torch.nn as nn
10import torch.nn.functional as F 11import torch.nn.functional as F
11from torch.utils.data import DataLoader 12from torch.utils.data import DataLoader
12 13
@@ -45,10 +46,20 @@ class TrainingCallbacks():
45 on_checkpoint: Callable[[int, str], None] = const() 46 on_checkpoint: Callable[[int, str], None] = const()
46 47
47 48
49class TrainingStrategyPrepareCallable(Protocol):
50 def __call__(
51 self,
52 accelerator: Accelerator,
53 text_encoder: CLIPTextModel,
54 unet: UNet2DConditionModel,
55 *args
56 ) -> Tuple: ...
57
58
48@dataclass 59@dataclass
49class TrainingStrategy(): 60class TrainingStrategy():
50 callbacks: Callable[..., TrainingCallbacks] 61 callbacks: Callable[..., TrainingCallbacks]
51 prepare_unet: bool = False 62 prepare: TrainingStrategyPrepareCallable
52 63
53 64
54def make_grid(images, rows, cols): 65def make_grid(images, rows, cols):
@@ -535,19 +546,8 @@ def train(
535 prior_loss_weight: float = 1.0, 546 prior_loss_weight: float = 1.0,
536 **kwargs, 547 **kwargs,
537): 548):
538 prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] 549 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
539 550 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
540 if strategy.prepare_unet:
541 prep.append(unet)
542
543 prep = accelerator.prepare(*prep)
544
545 if strategy.prepare_unet:
546 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
547 else:
548 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
549
550 unet.to(accelerator.device, dtype=dtype)
551 551
552 vae.to(accelerator.device, dtype=dtype) 552 vae.to(accelerator.device, dtype=dtype)
553 553
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index f57e736..1277939 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -6,6 +6,7 @@ from pathlib import Path
6import itertools 6import itertools
7 7
8import torch 8import torch
9import torch.nn as nn
9from torch.utils.data import DataLoader 10from torch.utils.data import DataLoader
10 11
11from accelerate import Accelerator 12from accelerate import Accelerator
@@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks(
186 ) 187 )
187 188
188 189
190def dreambooth_prepare(
191 accelerator: Accelerator,
192 text_encoder: CLIPTextModel,
193 unet: UNet2DConditionModel,
194 *args
195):
196 prep = [text_encoder, unet] + list(args)
197 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
198 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199
200
189dreambooth_strategy = TrainingStrategy( 201dreambooth_strategy = TrainingStrategy(
190 callbacks=dreambooth_strategy_callbacks, 202 callbacks=dreambooth_strategy_callbacks,
191 prepare_unet=True 203 prepare=dreambooth_prepare
192) 204)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index e922954..6a76f98 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext
5from pathlib import Path 5from pathlib import Path
6 6
7import torch 7import torch
8import torch.nn as nn
8from torch.utils.data import DataLoader 9from torch.utils.data import DataLoader
9 10
10from accelerate import Accelerator 11from accelerate import Accelerator
@@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks(
94 return nullcontext() 95 return nullcontext()
95 96
96 def on_model(): 97 def on_model():
97 return text_encoder 98 return text_encoder.text_model.embeddings.temp_token_embedding
98 99
99 def on_prepare(): 100 def on_prepare():
100 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) 101 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
@@ -163,6 +164,25 @@ def textual_inversion_strategy_callbacks(
163 ) 164 )
164 165
165 166
167def textual_inversion_prepare(
168 accelerator: Accelerator,
169 text_encoder: CLIPTextModel,
170 unet: UNet2DConditionModel,
171 *args
172):
173 weight_dtype = torch.float32
174 if accelerator.state.mixed_precision == "fp16":
175 weight_dtype = torch.float16
176 elif accelerator.state.mixed_precision == "bf16":
177 weight_dtype = torch.bfloat16
178
179 prep = [text_encoder] + list(args)
180 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
181 unet.to(accelerator.device, dtype=weight_dtype)
182 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
183
184
166textual_inversion_strategy = TrainingStrategy( 185textual_inversion_strategy = TrainingStrategy(
167 callbacks=textual_inversion_strategy_callbacks, 186 callbacks=textual_inversion_strategy_callbacks,
187 prepare=textual_inversion_prepare,
168) 188)