summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py35
-rw-r--r--training/strategy/lora.py147
-rw-r--r--training/strategy/ti.py38
3 files changed, 203 insertions, 17 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e88bf90..b4c77f3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks(
61 save_samples_ = partial( 61 save_samples_ = partial(
62 save_samples, 62 save_samples,
63 accelerator=accelerator, 63 accelerator=accelerator,
64 unet=unet,
65 text_encoder=text_encoder,
66 tokenizer=tokenizer, 64 tokenizer=tokenizer,
67 vae=vae, 65 vae=vae,
68 sample_scheduler=sample_scheduler, 66 sample_scheduler=sample_scheduler,
69 train_dataloader=train_dataloader, 67 train_dataloader=train_dataloader,
70 val_dataloader=val_dataloader, 68 val_dataloader=val_dataloader,
71 dtype=weight_dtype,
72 output_dir=sample_output_dir, 69 output_dir=sample_output_dir,
73 seed=seed, 70 seed=seed,
74 batch_size=sample_batch_size, 71 batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks(
94 else: 91 else:
95 return nullcontext() 92 return nullcontext()
96 93
97 def on_model(): 94 def on_accum_model():
98 return unet 95 return unet
99 96
100 def on_prepare(): 97 def on_prepare():
@@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks(
172 @torch.no_grad() 169 @torch.no_grad()
173 def on_sample(step): 170 def on_sample(step):
174 with ema_context(): 171 with ema_context():
175 save_samples_(step=step) 172 unet_ = accelerator.unwrap_model(unet)
173 text_encoder_ = accelerator.unwrap_model(text_encoder)
174
175 orig_unet_dtype = unet_.dtype
176 orig_text_encoder_dtype = text_encoder_.dtype
177
178 unet_.to(dtype=weight_dtype)
179 text_encoder_.to(dtype=weight_dtype)
180
181 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
182
183 unet_.to(dtype=orig_unet_dtype)
184 text_encoder_.to(dtype=orig_text_encoder_dtype)
185
186 del unet_
187 del text_encoder_
188
189 if torch.cuda.is_available():
190 torch.cuda.empty_cache()
176 191
177 return TrainingCallbacks( 192 return TrainingCallbacks(
178 on_prepare=on_prepare, 193 on_prepare=on_prepare,
179 on_model=on_model, 194 on_accum_model=on_accum_model,
180 on_train=on_train, 195 on_train=on_train,
181 on_eval=on_eval, 196 on_eval=on_eval,
182 on_before_optimize=on_before_optimize, 197 on_before_optimize=on_before_optimize,
@@ -191,9 +206,13 @@ def dreambooth_prepare(
191 accelerator: Accelerator, 206 accelerator: Accelerator,
192 text_encoder: CLIPTextModel, 207 text_encoder: CLIPTextModel,
193 unet: UNet2DConditionModel, 208 unet: UNet2DConditionModel,
194 *args 209 optimizer: torch.optim.Optimizer,
210 train_dataloader: DataLoader,
211 val_dataloader: Optional[DataLoader],
212 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
213 **kwargs
195): 214):
196 return accelerator.prepare(text_encoder, unet, *args) 215 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})
197 216
198 217
199dreambooth_strategy = TrainingStrategy( 218dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
new file mode 100644
index 0000000..88d1824
--- /dev/null
+++ b/training/strategy/lora.py
@@ -0,0 +1,147 @@
1from contextlib import nullcontext
2from typing import Optional
3from functools import partial
4from contextlib import contextmanager, nullcontext
5from pathlib import Path
6
7import torch
8import torch.nn as nn
9from torch.utils.data import DataLoader
10
11from accelerate import Accelerator
12from transformers import CLIPTextModel
13from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
14from diffusers.loaders import AttnProcsLayers
15
16from slugify import slugify
17
18from models.clip.tokenizer import MultiCLIPTokenizer
19from training.util import EMAModel
20from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
21
22
23def lora_strategy_callbacks(
24 accelerator: Accelerator,
25 unet: UNet2DConditionModel,
26 text_encoder: CLIPTextModel,
27 tokenizer: MultiCLIPTokenizer,
28 vae: AutoencoderKL,
29 sample_scheduler: DPMSolverMultistepScheduler,
30 train_dataloader: DataLoader,
31 val_dataloader: Optional[DataLoader],
32 sample_output_dir: Path,
33 checkpoint_output_dir: Path,
34 seed: int,
35 lora_layers: AttnProcsLayers,
36 max_grad_norm: float = 1.0,
37 sample_batch_size: int = 1,
38 sample_num_batches: int = 1,
39 sample_num_steps: int = 20,
40 sample_guidance_scale: float = 7.5,
41 sample_image_size: Optional[int] = None,
42):
43 sample_output_dir.mkdir(parents=True, exist_ok=True)
44 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
45
46 weight_dtype = torch.float32
47 if accelerator.state.mixed_precision == "fp16":
48 weight_dtype = torch.float16
49 elif accelerator.state.mixed_precision == "bf16":
50 weight_dtype = torch.bfloat16
51
52 save_samples_ = partial(
53 save_samples,
54 accelerator=accelerator,
55 unet=unet,
56 text_encoder=text_encoder,
57 tokenizer=tokenizer,
58 vae=vae,
59 sample_scheduler=sample_scheduler,
60 train_dataloader=train_dataloader,
61 val_dataloader=val_dataloader,
62 output_dir=sample_output_dir,
63 seed=seed,
64 batch_size=sample_batch_size,
65 num_batches=sample_num_batches,
66 num_steps=sample_num_steps,
67 guidance_scale=sample_guidance_scale,
68 image_size=sample_image_size,
69 )
70
71 def on_prepare():
72 lora_layers.requires_grad_(True)
73
74 def on_accum_model():
75 return unet
76
77 @contextmanager
78 def on_train(epoch: int):
79 tokenizer.train()
80 yield
81
82 @contextmanager
83 def on_eval():
84 tokenizer.eval()
85 yield
86
87 def on_before_optimize(lr: float, epoch: int):
88 if accelerator.sync_gradients:
89 accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
90
91 @torch.no_grad()
92 def on_checkpoint(step, postfix):
93 print(f"Saving checkpoint for step {step}...")
94 orig_unet_dtype = unet.dtype
95 unet.to(dtype=torch.float32)
96 unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
97 unet.to(dtype=orig_unet_dtype)
98
99 @torch.no_grad()
100 def on_sample(step):
101 orig_unet_dtype = unet.dtype
102 unet.to(dtype=weight_dtype)
103 save_samples_(step=step)
104 unet.to(dtype=orig_unet_dtype)
105
106 if torch.cuda.is_available():
107 torch.cuda.empty_cache()
108
109 return TrainingCallbacks(
110 on_prepare=on_prepare,
111 on_accum_model=on_accum_model,
112 on_train=on_train,
113 on_eval=on_eval,
114 on_before_optimize=on_before_optimize,
115 on_checkpoint=on_checkpoint,
116 on_sample=on_sample,
117 )
118
119
120def lora_prepare(
121 accelerator: Accelerator,
122 text_encoder: CLIPTextModel,
123 unet: UNet2DConditionModel,
124 optimizer: torch.optim.Optimizer,
125 train_dataloader: DataLoader,
126 val_dataloader: Optional[DataLoader],
127 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
128 lora_layers: AttnProcsLayers,
129 **kwargs
130):
131 weight_dtype = torch.float32
132 if accelerator.state.mixed_precision == "fp16":
133 weight_dtype = torch.float16
134 elif accelerator.state.mixed_precision == "bf16":
135 weight_dtype = torch.bfloat16
136
137 lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
138 lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
139 unet.to(accelerator.device, dtype=weight_dtype)
140 text_encoder.to(accelerator.device, dtype=weight_dtype)
141 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
142
143
144lora_strategy = TrainingStrategy(
145 callbacks=lora_strategy_callbacks,
146 prepare=lora_prepare,
147)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 14bdafd..d306f18 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks(
59 save_samples_ = partial( 59 save_samples_ = partial(
60 save_samples, 60 save_samples,
61 accelerator=accelerator, 61 accelerator=accelerator,
62 unet=unet,
63 text_encoder=text_encoder,
64 tokenizer=tokenizer, 62 tokenizer=tokenizer,
65 vae=vae, 63 vae=vae,
66 sample_scheduler=sample_scheduler, 64 sample_scheduler=sample_scheduler,
67 train_dataloader=train_dataloader, 65 train_dataloader=train_dataloader,
68 val_dataloader=val_dataloader, 66 val_dataloader=val_dataloader,
69 dtype=weight_dtype,
70 output_dir=sample_output_dir, 67 output_dir=sample_output_dir,
71 seed=seed, 68 seed=seed,
72 batch_size=sample_batch_size, 69 batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks(
94 else: 91 else:
95 return nullcontext() 92 return nullcontext()
96 93
97 def on_model(): 94 def on_accum_model():
98 return text_encoder.text_model.embeddings.temp_token_embedding 95 return text_encoder.text_model.embeddings.temp_token_embedding
99 96
100 def on_prepare(): 97 def on_prepare():
@@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks(
149 @torch.no_grad() 146 @torch.no_grad()
150 def on_sample(step): 147 def on_sample(step):
151 with ema_context(): 148 with ema_context():
152 save_samples_(step=step) 149 unet_ = accelerator.unwrap_model(unet)
150 text_encoder_ = accelerator.unwrap_model(text_encoder)
151
152 orig_unet_dtype = unet_.dtype
153 orig_text_encoder_dtype = text_encoder_.dtype
154
155 unet_.to(dtype=weight_dtype)
156 text_encoder_.to(dtype=weight_dtype)
157
158 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
159
160 unet_.to(dtype=orig_unet_dtype)
161 text_encoder_.to(dtype=orig_text_encoder_dtype)
162
163 del unet_
164 del text_encoder_
165
166 if torch.cuda.is_available():
167 torch.cuda.empty_cache()
153 168
154 return TrainingCallbacks( 169 return TrainingCallbacks(
155 on_prepare=on_prepare, 170 on_prepare=on_prepare,
156 on_model=on_model, 171 on_accum_model=on_accum_model,
157 on_train=on_train, 172 on_train=on_train,
158 on_eval=on_eval, 173 on_eval=on_eval,
159 on_before_optimize=on_before_optimize, 174 on_before_optimize=on_before_optimize,
@@ -168,7 +183,11 @@ def textual_inversion_prepare(
168 accelerator: Accelerator, 183 accelerator: Accelerator,
169 text_encoder: CLIPTextModel, 184 text_encoder: CLIPTextModel,
170 unet: UNet2DConditionModel, 185 unet: UNet2DConditionModel,
171 *args 186 optimizer: torch.optim.Optimizer,
187 train_dataloader: DataLoader,
188 val_dataloader: Optional[DataLoader],
189 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
190 **kwargs
172): 191):
173 weight_dtype = torch.float32 192 weight_dtype = torch.float32
174 if accelerator.state.mixed_precision == "fp16": 193 if accelerator.state.mixed_precision == "fp16":
@@ -176,9 +195,10 @@ def textual_inversion_prepare(
176 elif accelerator.state.mixed_precision == "bf16": 195 elif accelerator.state.mixed_precision == "bf16":
177 weight_dtype = torch.bfloat16 196 weight_dtype = torch.bfloat16
178 197
179 prepped = accelerator.prepare(text_encoder, *args) 198 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
199 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
180 unet.to(accelerator.device, dtype=weight_dtype) 200 unet.to(accelerator.device, dtype=weight_dtype)
181 return (prepped[0], unet) + prepped[1:] 201 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
182 202
183 203
184textual_inversion_strategy = TrainingStrategy( 204textual_inversion_strategy = TrainingStrategy(