Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  101
1 file changed, 56 insertions(+), 45 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 1b6162b..c6b4dc3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def save_samples(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     dtype: torch.dtype,
     output_dir: Path,
     seed: int,
@@ -111,11 +111,13 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    for pool, data, gen in [
-        ("stable", val_dataloader, generator),
-        ("val", val_dataloader, None),
-        ("train", train_dataloader, None)
-    ]:
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+
+    if val_dataloader is not None:
+        datasets.append(("stable", val_dataloader, generator))
+        datasets.append(("val", val_dataloader, None))
+
+    for pool, data, gen in datasets:
         all_samples = []
         file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)
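The hunk above reworks the pool selection in save_samples: "train" samples are now always generated, while the seeded "stable" pool and the unseeded "val" pool are only produced when a validation dataloader exists. Note the iteration order also changes, with "train" coming first rather than last. A minimal sketch of that selection logic in isolation (select_pools and the stand-in list/string arguments are hypothetical, for illustration only):

    from typing import Optional

    def select_pools(train_dataloader, val_dataloader, generator: Optional[object]):
        # Mirrors the new logic: "train" is always sampled; the seeded
        # "stable" pool and the unseeded "val" pool exist only when a
        # validation dataloader is supplied.
        datasets = [("train", train_dataloader, None)]
        if val_dataloader is not None:
            datasets.append(("stable", val_dataloader, generator))
            datasets.append(("val", val_dataloader, None))
        return datasets

    # With a validation set: three pools.
    print([p for p, _, _ in select_pools(["b0"], ["b1"], "gen")])  # ['train', 'stable', 'val']
    # Without one: only the train pool is sampled.
    print([p for p, _, _ in select_pools(["b0"], None, None)])     # ['train']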
@@ -328,7 +330,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
@@ -337,7 +339,7 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader)
+    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
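Guarding the length call keeps the step accounting (and thus the progress-bar totals) consistent when no validation set is configured. A small sketch of the arithmetic, with invented counts for illustration:

    import math

    gradient_accumulation_steps = 4
    num_epochs = 10
    train_batches = 100            # hypothetical len(train_dataloader)
    val_dataloader = None          # no validation set configured

    num_training_steps_per_epoch = math.ceil(train_batches / gradient_accumulation_steps)
    # With val_dataloader=None the validation term drops out entirely.
    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0

    num_training_steps = num_training_steps_per_epoch * num_epochs   # 250
    num_val_steps = num_val_steps_per_epoch * num_epochs             # 0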
@@ -350,6 +352,7 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    max_acc = 0.0
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
@@ -432,49 +435,57 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
-            model.eval()
-
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
-
-            with torch.inference_mode(), on_eval():
-                for step, batch in enumerate(val_dataloader):
-                    loss, acc, bsz = loss_step(step, batch, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
-
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
-
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
-                    logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
-                        "val/cur_loss": loss.item(),
-                        "val/cur_acc": acc.item(),
-                    }
-                    local_progress_bar.set_postfix(**logs)
-
-                logs["val/cur_loss"] = cur_loss_val.avg.item()
-                logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-                accelerator.log(logs, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
-                    max_acc_val = avg_acc_val.avg.item()
+            if val_dataloader is not None:
+                model.eval()
+
+                cur_loss_val = AverageMeter()
+                cur_acc_val = AverageMeter()
+
+                with torch.inference_mode(), on_eval():
+                    for step, batch in enumerate(val_dataloader):
+                        loss, acc, bsz = loss_step(step, batch, True)
+
+                        loss = loss.detach_()
+                        acc = acc.detach_()
+
+                        cur_loss_val.update(loss, bsz)
+                        cur_acc_val.update(acc, bsz)
+
+                        avg_loss_val.update(loss, bsz)
+                        avg_acc_val.update(acc, bsz)
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        logs = {
+                            "val/loss": avg_loss_val.avg.item(),
+                            "val/acc": avg_acc_val.avg.item(),
+                            "val/cur_loss": loss.item(),
+                            "val/cur_acc": acc.item(),
+                        }
+                        local_progress_bar.set_postfix(**logs)
+
+                    logs["val/cur_loss"] = cur_loss_val.avg.item()
+                    logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+                    accelerator.log(logs, step=global_step)
+
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
+                if accelerator.is_main_process:
+                    if avg_acc_val.avg.item() > max_acc_val:
+                        accelerator.print(
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        max_acc_val = avg_acc_val.avg.item()
+            else:
+                if accelerator.is_main_process:
+                    if avg_acc.avg.item() > max_acc:
+                        accelerator.print(
+                            f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        max_acc = avg_acc.avg.item()
 
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
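The core behavioral change in the hunk above is that milestone checkpointing now degrades gracefully: with a validation set, milestones are keyed to the best validation accuracy (max_acc_val); without one, the newly introduced max_acc tracker keys them to the best training accuracy instead. A distilled sketch of that decision (maybe_checkpoint_milestone and its scalar arguments are hypothetical stand-ins for the meters and callback in the real loop):

    def maybe_checkpoint_milestone(global_step, avg_acc, avg_acc_val,
                                   max_acc, max_acc_val, have_val, on_checkpoint):
        # With a validation set, milestones track the best validation
        # accuracy; otherwise they fall back to the best training accuracy.
        if have_val:
            if avg_acc_val > max_acc_val:
                on_checkpoint(global_step, "milestone")
                max_acc_val = avg_acc_val
        else:
            if avg_acc > max_acc:
                on_checkpoint(global_step, "milestone")
                max_acc = avg_acc
        return max_acc, max_acc_val

    # Example: no validation set, training accuracy improved 0.10 -> 0.20,
    # so a milestone checkpoint fires and max_acc is updated.
    max_acc, _ = maybe_checkpoint_milestone(100, 0.20, 0.0, 0.10, 0.0, False,
                                            lambda step, kind: print(step, kind))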
@@ -499,7 +510,7 @@ def train(
     seed: int,
     project: str,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
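Taken together, the widened signatures of save_samples, train_loop, and train let a caller run a train-only job by omitting the validation set. A minimal, runnable sketch of the calling pattern (run is a hypothetical stand-in for train, and the TensorDataset loaders are toy data):

    from typing import Optional

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def run(train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None) -> None:
        # Mirrors the guarded pattern used throughout the commit: every
        # val-dependent quantity collapses to zero when no loader is given.
        num_val_steps = len(val_dataloader) if val_dataloader is not None else 0
        print(f"train batches: {len(train_dataloader)}, val batches: {num_val_steps}")

    train_loader = DataLoader(TensorDataset(torch.zeros(8, 2)), batch_size=4)

    run(train_loader)  # train-only: val_dataloader defaults to None
    run(train_loader, DataLoader(TensorDataset(torch.zeros(4, 2)), batch_size=4))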