path: root/training/functional.py
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  221
1 file changed, 149 insertions(+), 72 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    DDPMScheduler,
+    UniPCMultistepScheduler,
+    SchedulerMixin,
+)
 
 from tqdm.auto import tqdm
 
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
 def const(result=None):
     def fn(*args, **kwargs):
         return result
+
     return fn
 
 
 @dataclass
-class TrainingCallbacks():
+class TrainingCallbacks:
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
         train_dataloader: DataLoader,
         val_dataloader: Optional[DataLoader],
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        **kwargs
-    ) -> Tuple: ...
+        **kwargs,
+    ) -> Tuple:
+        ...
 
 
 @dataclass
-class TrainingStrategy():
+class TrainingStrategy:
     callbacks: Callable[..., TrainingCallbacks]
     prepare: TrainingStrategyPrepareCallable
 
 
 def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer"
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
+    )
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
@@ -113,7 +133,9 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
+        ("train", train_dataloader, None)
+    ]
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
@@ -124,17 +146,11 @@ def save_samples(
         file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
-        prompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["prompt_ids"]
-        ]
-        nprompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["nprompt_ids"]
-        ]
+        batches = list(
+            itertools.islice(itertools.cycle(data), batch_size * num_batches)
+        )
+        prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
+        nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]
 
         with torch.inference_mode():
             for i in range(num_batches):
@@ -165,7 +181,9 @@ def save_samples(
                 pass
 
         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
-        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
+        image_grid = pipeline.numpy_to_pil(
+            image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
+        )[0]
         image_grid.save(file_path, quality=85)
 
     del generator, pipeline
@@ -184,15 +202,17 @@ def generate_class_images(
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
-    sample_steps: int
+    sample_steps: int,
 ):
-    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
+    missing_data = [
+        item for item in train_dataset.items if not item.class_image_path.exists()
+    ]
 
     if len(missing_data) == 0:
         return
 
     batched_data = [
-        missing_data[i:i+sample_batch_size]
+        missing_data[i : i + sample_batch_size]
         for i in range(0, len(missing_data), sample_batch_size)
     ]
 
@@ -216,7 +236,7 @@ def generate_class_images(
                 negative_prompt=nprompt,
                 height=sample_image_size,
                 width=sample_image_size,
-                num_inference_steps=sample_steps
+                num_inference_steps=sample_steps,
             ).images
 
             for i, image in enumerate(images):
@@ -245,8 +265,12 @@ def add_placeholder_tokens(
 
     embeddings.resize(len(tokenizer))
 
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
+    for placeholder_token_id, initializer_token_id in zip(
+        placeholder_token_ids, initializer_token_ids
+    ):
+        embeddings.add_embed(
+            placeholder_token_id, initializer_token_id, initializer_noise
+        )
 
     return placeholder_token_ids, initializer_token_ids
 
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler):
 
     # Expand the tensors.
     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
 
 
 def get_original(
-    noise_scheduler,
-    model_output,
-    sample: torch.FloatTensor,
-    timesteps: torch.IntTensor
+    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
 ):
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
 
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(sample.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
@@ -329,7 +358,9 @@ def loss_step(
     eval: bool = False,
 ):
     images = batch["pixel_values"]
-    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    generator = (
+        torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    )
     bsz = images.shape[0]
 
     # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
-        generator=generator
+        generator=generator,
     )
     applied_noise = noise
 
@@ -353,7 +384,7 @@ def loss_step(
             octaves=4,
             dtype=latents.dtype,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
 
     # Get the text embedding for conditioning
     encoder_hidden_states = get_extended_embeddings(
-        text_encoder,
-        batch["input_ids"],
-        batch["attention_mask"]
+        text_encoder, batch["input_ids"], batch["attention_mask"]
     )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+    model_pred = unet(
+        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
+    )[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder,
-            batch["negative_input_ids"],
-            batch["negative_attention_mask"]
+            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
-        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+        model_pred_uncond = unet(
+            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
+        )[0]
+        model_pred = model_pred_uncond + guidance_scale * (
+            model_pred - model_pred_uncond
+        )
 
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
@@ -401,7 +434,9 @@ def loss_step(
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
     else:
-        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        raise ValueError(
+            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+        )
 
     acc = (model_pred == target).float().mean()
 
@@ -414,7 +449,9 @@ def loss_step(
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+        prior_loss = F.mse_loss(
+            model_pred_prior.float(), target_prior.float(), reduction="none"
+        )
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
@@ -433,7 +470,10 @@ def loss_step(
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
+            )[0]
+            / snr
         )
         loss = loss * mse_loss_weights
 
@@ -447,8 +487,14 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
-                 eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(
+        self,
+        step: int,
+        batch: dict[Any, Any],
+        cache: dict[str, Any],
+        eval: bool = False,
+    ) -> Tuple[Any, Any, int]:
+        ...
 
 
 def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
     avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = math.ceil(
-        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = (
+        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
+        if val_dataloader is not None
+        else 0
+    )
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     global_progress_bar.set_description("Total progress")
 
@@ -513,7 +564,9 @@ def train_loop(
     try:
         import dadaptation
 
-        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+        isDadaptation = isinstance(
+            optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
+        )
     except ImportError:
         pass
 
@@ -565,7 +618,10 @@ def train_loop(
                     label = group_labels[i] if i < len(group_labels) else f"{i}"
                     logs[f"lr/{label}"] = lr
                     if isDadaptation:
-                        lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                        lr = (
+                            optimizer.param_groups[i]["d"]
+                            * optimizer.param_groups[i]["lr"]
+                        )
                         logs[f"d*lr/{label}"] = lr
                     lrs[label] = lr
 
@@ -573,8 +629,10 @@ def train_loop(
 
                 local_progress_bar.set_postfix(**logs)
 
-                if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(epoch)
+                if ((step + 1) % gradient_accumulation_steps == 0) or (
+                    (step + 1) == len(train_dataloader)
+                ):
+                    before_optimize_result = on_before_optimize(cycle)
 
                     optimizer.step()
                     lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
                     }
                     local_progress_bar.set_postfix(**logs)
 
-                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                    if ((step + 1) % gradient_accumulation_steps == 0) or (
+                        (step + 1) == len(val_dataloader)
+                    ):
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
@@ -634,7 +694,8 @@ def train_loop(
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
+                    )
                     on_checkpoint(global_step, "milestone")
                     best_acc_val = avg_acc_val.max
         else:
@@ -644,7 +705,8 @@ def train_loop(
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
+                    )
                     on_checkpoint(global_step, "milestone")
                     best_acc = avg_acc.max
 
@@ -700,17 +762,32 @@ def train(
     avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = strategy.prepare(
+        accelerator,
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+        **kwargs,
+    )
 
     vae.to(accelerator.device, dtype=dtype)
     vae.requires_grad_(False)
     vae.eval()
 
-    vae = torch.compile(vae, backend='hidet')
+    vae = torch.compile(vae, backend="hidet")
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, backend="hidet")
         # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(