Diffstat (limited to 'training')
-rw-r--r--  training/functional.py           221
-rw-r--r--  training/lr.py                     4
-rw-r--r--  training/optimization.py          38
-rw-r--r--  training/sampler.py                2
-rw-r--r--  training/strategy/dreambooth.py   29
-rw-r--r--  training/strategy/lora.py         41
-rw-r--r--  training/strategy/ti.py           27
7 files changed, 245 insertions(+), 117 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    DDPMScheduler,
+    UniPCMultistepScheduler,
+    SchedulerMixin,
+)
 
 from tqdm.auto import tqdm
 
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
 def const(result=None):
     def fn(*args, **kwargs):
         return result
+
     return fn
 
 
 @dataclass
-class TrainingCallbacks():
+class TrainingCallbacks:
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
         train_dataloader: DataLoader,
         val_dataloader: Optional[DataLoader],
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        **kwargs
-    ) -> Tuple: ...
+        **kwargs,
+    ) -> Tuple:
+        ...
 
 
 @dataclass
-class TrainingStrategy():
+class TrainingStrategy:
     callbacks: Callable[..., TrainingCallbacks]
     prepare: TrainingStrategyPrepareCallable
 
 
 def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer"
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
+    )
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
@@ -113,7 +133,9 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
+        ("train", train_dataloader, None)
+    ]
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
@@ -124,17 +146,11 @@ def save_samples(
         file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
-        prompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["prompt_ids"]
-        ]
-        nprompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["nprompt_ids"]
-        ]
+        batches = list(
+            itertools.islice(itertools.cycle(data), batch_size * num_batches)
+        )
+        prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
+        nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]
 
         with torch.inference_mode():
             for i in range(num_batches):
@@ -165,7 +181,9 @@ def save_samples(
                     pass
 
         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
-        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
+        image_grid = pipeline.numpy_to_pil(
+            image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
+        )[0]
         image_grid.save(file_path, quality=85)
 
     del generator, pipeline
@@ -184,15 +202,17 @@ def generate_class_images(
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
-    sample_steps: int
+    sample_steps: int,
 ):
-    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
+    missing_data = [
+        item for item in train_dataset.items if not item.class_image_path.exists()
+    ]
 
     if len(missing_data) == 0:
         return
 
     batched_data = [
-        missing_data[i:i+sample_batch_size]
+        missing_data[i : i + sample_batch_size]
         for i in range(0, len(missing_data), sample_batch_size)
     ]
 
@@ -216,7 +236,7 @@ def generate_class_images(
             negative_prompt=nprompt,
             height=sample_image_size,
             width=sample_image_size,
-            num_inference_steps=sample_steps
+            num_inference_steps=sample_steps,
         ).images
 
         for i, image in enumerate(images):
@@ -245,8 +265,12 @@ def add_placeholder_tokens(
 
     embeddings.resize(len(tokenizer))
 
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
+    for placeholder_token_id, initializer_token_id in zip(
+        placeholder_token_ids, initializer_token_ids
+    ):
+        embeddings.add_embed(
+            placeholder_token_id, initializer_token_id, initializer_noise
+        )
 
     return placeholder_token_ids, initializer_token_ids
 
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler):
 
     # Expand the tensors.
     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
 
 
 def get_original(
-    noise_scheduler,
-    model_output,
-    sample: torch.FloatTensor,
-    timesteps: torch.IntTensor
+    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
 ):
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
 
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(sample.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
@@ -329,7 +358,9 @@ def loss_step(
     eval: bool = False,
 ):
     images = batch["pixel_values"]
-    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    generator = (
+        torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    )
     bsz = images.shape[0]
 
     # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
-        generator=generator
+        generator=generator,
     )
     applied_noise = noise
 
@@ -353,7 +384,7 @@ def loss_step(
             octaves=4,
             dtype=latents.dtype,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
 
     # Get the text embedding for conditioning
     encoder_hidden_states = get_extended_embeddings(
-        text_encoder,
-        batch["input_ids"],
-        batch["attention_mask"]
+        text_encoder, batch["input_ids"], batch["attention_mask"]
     )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+    model_pred = unet(
+        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
+    )[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder,
-            batch["negative_input_ids"],
-            batch["negative_attention_mask"]
+            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
-        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+        model_pred_uncond = unet(
+            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
+        )[0]
+        model_pred = model_pred_uncond + guidance_scale * (
+            model_pred - model_pred_uncond
+        )
 
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
@@ -401,7 +434,9 @@ def loss_step(
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
     else:
-        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        raise ValueError(
+            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+        )
 
     acc = (model_pred == target).float().mean()
 
@@ -414,7 +449,9 @@ def loss_step(
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+        prior_loss = F.mse_loss(
+            model_pred_prior.float(), target_prior.float(), reduction="none"
+        )
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
@@ -433,7 +470,10 @@ def loss_step(
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
+            )[0]
+            / snr
         )
         loss = loss * mse_loss_weights
 
@@ -447,8 +487,14 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
-                 eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(
+        self,
+        step: int,
+        batch: dict[Any, Any],
+        cache: dict[str, Any],
+        eval: bool = False,
+    ) -> Tuple[Any, Any, int]:
+        ...
 
 
 def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
     avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = math.ceil(
-        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = (
+        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
+        if val_dataloader is not None
+        else 0
+    )
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     global_progress_bar.set_description("Total progress")
 
@@ -513,7 +564,9 @@ def train_loop(
     try:
         import dadaptation
 
-        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+        isDadaptation = isinstance(
+            optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
+        )
     except ImportError:
         pass
 
@@ -565,7 +618,10 @@ def train_loop(
                 label = group_labels[i] if i < len(group_labels) else f"{i}"
                 logs[f"lr/{label}"] = lr
                 if isDadaptation:
-                    lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                    lr = (
+                        optimizer.param_groups[i]["d"]
+                        * optimizer.param_groups[i]["lr"]
+                    )
                     logs[f"d*lr/{label}"] = lr
                 lrs[label] = lr
 
@@ -573,8 +629,10 @@ def train_loop(
 
             local_progress_bar.set_postfix(**logs)
 
-            if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                before_optimize_result = on_before_optimize(epoch)
+            if ((step + 1) % gradient_accumulation_steps == 0) or (
+                (step + 1) == len(train_dataloader)
+            ):
+                before_optimize_result = on_before_optimize(cycle)
 
                 optimizer.step()
                 lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
             }
             local_progress_bar.set_postfix(**logs)
 
-            if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+            if ((step + 1) % gradient_accumulation_steps == 0) or (
+                (step + 1) == len(val_dataloader)
+            ):
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
@@ -634,7 +694,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc_val = avg_acc_val.max
             else:
@@ -644,7 +705,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
+                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc = avg_acc.max
 
@@ -700,17 +762,32 @@ def train(
     avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = strategy.prepare(
+        accelerator,
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+        **kwargs,
+    )
 
     vae.to(accelerator.device, dtype=dtype)
     vae.requires_grad_(False)
     vae.eval()
 
-    vae = torch.compile(vae, backend='hidet')
+    vae = torch.compile(vae, backend="hidet")
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, backend="hidet")
         # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
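
For readers following the `min_snr_gamma` branch reformatted above: the weighting clamps each timestep's signal-to-noise ratio at gamma and divides by the raw SNR, so noisy timesteps keep full weight while nearly-clean ones are down-weighted. A minimal standalone sketch, assuming only a DDPM-style `alphas_cumprod` table (illustrative, not code imported from this repo):

```python
import torch


def min_snr_weights(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor, gamma: float) -> torch.Tensor:
    # SNR(t) = (alpha_t / sigma_t)^2, as in compute_snr above.
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    snr = (alpha / sigma) ** 2
    # Clamp at gamma, then normalize by SNR: low-SNR (noisy) timesteps get
    # weight 1, high-SNR (clean) timesteps get gamma / SNR < 1.
    return torch.minimum(snr, torch.full_like(snr, gamma)) / snr


if __name__ == "__main__":
    alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)  # toy schedule
    t = torch.randint(0, 1000, (8,))
    print(min_snr_weights(t, alphas_cumprod, gamma=5.0))
```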
diff --git a/training/lr.py b/training/lr.py
index f5b362f..a75078f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -23,12 +23,12 @@ def plot_metrics(
     fig, ax_loss = plt.subplots()
     ax_acc = ax_loss.twinx()
 
-    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.plot(lrs, losses, color="red")
     ax_loss.set_xscale("log")
     ax_loss.set_xlabel(f"Learning rate")
     ax_loss.set_ylabel("Loss")
 
-    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.plot(lrs, accs, color="blue")
     ax_acc.set_xscale("log")
     ax_acc.set_ylabel("Accuracy")
 
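
The `plot_metrics` fragment above draws loss and accuracy on twin y-axes against a log-scaled learning-rate axis, the usual LR-finder layout. A self-contained sketch of the same layout with synthetic data (illustrative only; the real function receives recorded `lrs`, `losses`, and `accs`):

```python
import matplotlib.pyplot as plt
import numpy as np

lrs = np.logspace(-6, -2, 50)
losses = 1.0 / (1.0 + np.exp(-(np.log10(lrs) + 4)))  # synthetic loss curve
accs = 1.0 - losses                                  # synthetic accuracy curve

fig, ax_loss = plt.subplots()
ax_acc = ax_loss.twinx()  # second y-axis sharing the log-scaled x-axis

ax_loss.plot(lrs, losses, color="red")
ax_loss.set_xscale("log")
ax_loss.set_xlabel("Learning rate")
ax_loss.set_ylabel("Loss")

ax_acc.plot(lrs, accs, color="blue")
ax_acc.set_xscale("log")
ax_acc.set_ylabel("Accuracy")

fig.savefig("lr_finder.png")
```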
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
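
The `exponential_growth` branch relies on `get_exponential_growing_schedule`, whose lambda grows the LR geometrically from the optimizer's base LR to `end_lr` via `lr(t) = base_lr * (end_lr / base_lr) ** (t / num_training_steps)`. A minimal single-group sketch using `LambdaLR` (the repo's version also binds per-group base LRs through `functools.partial`; values here are placeholders):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
base_lr, end_lr, num_training_steps = 1e-6, 1e-2, 100
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

scheduler = LambdaLR(
    optimizer,
    # LambdaLR multiplies the base LR by this factor at every step.
    lambda step: (end_lr / base_lr) ** (step / num_training_steps),
)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```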
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
     def weights(self):
         if not self._warmed_up():
             return np.ones([self.num_timesteps], dtype=np.float64)
-        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
         weights /= np.sum(weights)
         weights *= 1 - self.uniform_prob
         weights += self.uniform_prob / len(weights)
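
For context on the `weights` method above: `LossSecondMomentResampler` draws timesteps in proportion to the square root of the second moment of recent losses, mixed with a small uniform floor so every timestep keeps nonzero probability. A standalone sketch of the weight computation, assuming an already-filled loss history (synthetic data, not an import from this repo):

```python
import numpy as np

num_timesteps, history_per_term, uniform_prob = 1000, 10, 0.001
loss_history = np.abs(np.random.randn(num_timesteps, history_per_term))

weights = np.sqrt(np.mean(loss_history**2, axis=-1))  # sqrt of E[loss^2]
weights /= np.sum(weights)          # normalize to a distribution
weights *= 1 - uniform_prob         # reserve mass for the uniform floor
weights += uniform_prob / len(weights)

# Sample a batch of timesteps according to the importance weights.
timesteps = np.random.choice(num_timesteps, size=16, p=weights)
print(timesteps)
```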
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
             return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()
 
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()
 
         yield
 
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
@@ -198,6 +204,5 @@ def dreambooth_prepare(
 
 
 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
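
The `on_before_optimize` callback above clips the UNet gradients, plus the text-encoder gradients during the first `train_text_encoder_cycles` cycles, under a single global norm by chaining the parameter iterators. A plain-torch sketch of that chained clipping, using `torch.nn.utils.clip_grad_norm_` in place of `accelerator.clip_grad_norm_` (module shapes are placeholders):

```python
import itertools

import torch

unet = torch.nn.Linear(8, 8)          # stands in for the UNet
text_encoder = torch.nn.Linear(8, 8)  # stands in for the text encoder
max_grad_norm = 1.0

loss = unet(text_encoder(torch.randn(2, 8))).sum()
loss.backward()

params_to_clip = [unet.parameters()]
train_text_encoder = True  # stands in for `cycle < train_text_encoder_cycles`
if train_text_encoder:
    params_to_clip.append(text_encoder.parameters())

# All chained parameters are treated as one group under one global norm.
torch.nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
```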
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
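
Both the LoRA and TI callbacks apply the same embedding-norm decay after each optimizer step: every embedding row is nudged toward `emb_decay_target` along its own direction, shrinking rows above the target norm and growing rows below it. A standalone sketch (`w`, `lambda_`, and the target value are illustrative placeholders; in the callbacks `lambda_` is derived from `emb_decay` and the current LR):

```python
import torch

w = torch.randn(4, 768)  # stands in for the placeholder-token embedding rows
emb_decay_target, lambda_ = 0.4, 0.1

with torch.no_grad():
    norm = w.norm(dim=-1, keepdim=True)
    # Move each row along its own unit direction toward the target norm;
    # clamp_min guards against division by zero for degenerate rows.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(w.norm(dim=-1))  # row norms pulled toward emb_decay_target
```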
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
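
In `textual_inversion_prepare`, only the text encoder is passed through `accelerator.prepare`; the UNet stays out of the optimizer, is cast to the mixed-precision dtype, frozen, and kept in eval mode. A plain-torch sketch of that dtype selection and freezing (a `Linear` stands in for the UNet; no accelerate dependency):

```python
import torch

mixed_precision = "bf16"  # stands in for accelerator.state.mixed_precision
weight_dtype = torch.float32
if mixed_precision == "fp16":
    weight_dtype = torch.float16
elif mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

unet = torch.nn.Linear(8, 8)  # stands in for the UNet
unet.to(dtype=weight_dtype)   # frozen model can live in the low-precision dtype
unet.requires_grad_(False)
unet.eval()
```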