Diffstat (limited to 'training/common.py')
-rw-r--r--  training/common.py  260
1 file changed, 239 insertions(+), 21 deletions(-)
diff --git a/training/common.py b/training/common.py
index 180396e..73ce814 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,46 +1,77 @@
 import math
+from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
+import datetime
+import logging
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from transformers import CLIPTokenizer, CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from accelerate.utils import LoggerType, set_seed
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from tqdm.auto import tqdm
+from slugify import slugify
 
+from data.csv import VlpnDataModule, VlpnDataItem
+from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
+class TrainingSetup(NamedTuple):
+    accelerator: Accelerator
+    tokenizer: MultiCLIPTokenizer
+    text_encoder: CLIPTextModel
+    vae: AutoencoderKL
+    unet: UNet2DConditionModel
+    noise_scheduler: DDPMScheduler
+    checkpoint_scheduler: DPMSolverMultistepScheduler
+    optimizer_class: Callable
+    learning_rate: float
+    weight_dtype: torch.dtype
+    output_dir: Path
+    seed: int
+    train_dataloader: DataLoader
+    val_dataloader: DataLoader
+    placeholder_token: list[str]
+    placeholder_token_ids: list[list[int]]
+
+
 def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 def noop_on_log():
     return {}
 
 
 def get_scheduler(
     id: str,
-    min_lr: float,
-    lr: float,
-    warmup_func: str,
-    annealing_func: str,
-    warmup_exp: int,
-    annealing_exp: int,
-    cycles: int,
-    train_epochs: int,
-    warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
@@ -49,8 +80,6 @@ def get_scheduler(
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
-        min_lr = 0.04 if min_lr is None else min_lr / lr
-
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=num_training_steps,
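
The scheduler tuning knobs are now keyword arguments with defaults, so most callers only need to name the optimizer and step counts. A minimal sketch of the new call style, assuming a caller that already has a model and a train_dataloader (the concrete values are illustrative, not from this patch):

    # Hypothetical usage of the reworked get_scheduler signature; every
    # omitted argument falls back to the defaults above (min_lr=0.04,
    # cosine warmup/annealing, a single cycle, 100 train epochs).
    lr_scheduler = get_scheduler(
        "one_cycle",
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
        num_training_steps_per_epoch=len(train_dataloader),
        gradient_accumulation_steps=4,
    )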
@@ -133,6 +162,196 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
+def train_setup(
+    output_dir: str,
+    project: str,
+    pretrained_model_name_or_path: str,
+    learning_rate: float,
+    data_file: str,
+    gradient_accumulation_steps: int = 1,
+    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
+    seed: Optional[int] = None,
+    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
+    vector_dropout: float = 0.1,
+    gradient_checkpointing: bool = True,
+    embeddings_dir: Optional[str] = None,
+    placeholder_token: list[str] = [],
+    initializer_token: list[str] = [],
+    num_vectors: int = 1,
+    scale_lr: bool = False,
+    use_8bit_adam: bool = False,
+    train_batch_size: int = 1,
+    class_image_dir: Optional[str] = None,
+    num_class_images: int = 0,
+    resolution: int = 768,
+    num_buckets: int = 0,
+    progressive_buckets: bool = False,
+    bucket_step_size: int = 64,
+    bucket_max_pixels: Optional[int] = None,
+    tag_dropout: float = 0.1,
+    tag_shuffle: bool = True,
+    data_template: str = "template",
+    valid_set_size: Optional[int] = None,
+    valid_set_repeat: int = 1,
+    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
+    sample_batch_size: int = 1,
+    sample_image_size: int = 768,
+    sample_steps: int = 20,
+) -> TrainingSetup:
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(output_dir).joinpath(slugify(project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        mixed_precision=mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    seed = seed or (torch.random.seed() >> 32)
+    set_seed(seed)
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(vector_shuffle)
+    tokenizer.set_dropout(vector_dropout)
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    if embeddings_dir is not None:
+        embeddings_dir = Path(embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    # Convert the initializer_token and placeholder_token to ids
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_token
+    ]
+
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(new_id, init_ids)
+
+    init_ratios = [
+        f"{len(init_ids)} / {len(new_id)}"
+        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
+    ]
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    if scale_lr:
+        learning_rate = (
+            learning_rate * gradient_accumulation_steps *
+            train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16 GB GPUs
+    if use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    datamodule = VlpnDataModule(
+        data_file=data_file,
+        batch_size=train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=class_image_dir,
+        num_class_images=num_class_images,
+        size=resolution,
+        num_buckets=num_buckets,
+        progressive_buckets=progressive_buckets,
+        bucket_step_size=bucket_step_size,
+        bucket_max_pixels=bucket_max_pixels,
+        dropout=tag_dropout,
+        shuffle=tag_shuffle,
+        template_key=data_template,
+        valid_set_size=valid_set_size,
+        valid_set_repeat=valid_set_repeat,
+        seed=seed,
+        filter=data_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
+
+    if num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            sample_batch_size,
+            sample_image_size,
+            sample_steps
+        )
+
+    return TrainingSetup(
+        accelerator=accelerator,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        noise_scheduler=noise_scheduler,
+        checkpoint_scheduler=checkpoint_scheduler,
+        optimizer_class=optimizer_class,
+        learning_rate=learning_rate,
+        output_dir=output_dir,
+        weight_dtype=weight_dtype,
+        seed=seed,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        placeholder_token=placeholder_token,
+        placeholder_token_ids=placeholder_token_ids
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
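
Because train_setup bundles everything into the TrainingSetup NamedTuple, a caller can build the full training environment in one call and then unpack fields by name. A hedged sketch of such an invocation (paths, the project name, the model id, and the tokens below are placeholders, not values from this patch):

    # Hypothetical invocation; every string below is a placeholder.
    setup = train_setup(
        output_dir="output",
        project="my-embedding",
        pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1",
        learning_rate=1e-4,
        data_file="data/train.csv",
        placeholder_token=["<token>"],
        initializer_token=["person"],
        num_vectors=4,
    )
    # NamedTuple fields are accessible by name.
    print(setup.output_dir, setup.seed, setup.weight_dtype)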
@@ -221,15 +440,14 @@ def train_loop(
     sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
-    gradient_accumulation_steps: int = 1,
     num_epochs: int = 100,
     on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-    on_before_optimize: Callable[[], None] = noop,
+    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+    on_before_optimize: Callable[[int], None] = noop,
     on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
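
train_loop no longer takes gradient_accumulation_steps as a separate argument; it reads the factor off the prepared Accelerator. A one-line sketch under that assumption (the factor of 4 is illustrative):

    # The accumulation factor now lives on the Accelerator built in
    # train_setup(); train_loop derives steps-per-epoch from it.
    accelerator = Accelerator(gradient_accumulation_steps=4)
    steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)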
@@ -273,14 +491,14 @@ def train_loop(
 
         model.train()
 
-        with on_train():
+        with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(model):
                     loss, acc, bsz = loss_step(step, batch)
 
                     accelerator.backward(loss)
 
-                    on_before_optimize()
+                    on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
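
With this change, on_train and on_before_optimize receive the current epoch, and the context-manager hooks default to noop_ctx. A minimal sketch of a caller-supplied hook matching the new Callable[[int], _GeneratorContextManager] signature (the hook body is hypothetical):

    from contextlib import contextmanager

    # Hypothetical epoch-aware hook; train_loop would receive it as
    # on_train=log_epoch and enter it once per epoch.
    @contextmanager
    def log_epoch(epoch: int):
        print(f"epoch {epoch}: entering train phase")
        try:
            yield
        finally:
            print(f"epoch {epoch}: leaving train phase")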