summary refs log tree commit diff stats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  58
1 file changed, 32 insertions(+), 26 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
461 num_epochs: int = 100, 461 num_epochs: int = 100,
462 gradient_accumulation_steps: int = 1, 462 gradient_accumulation_steps: int = 1,
463 group_labels: list[str] = [], 463 group_labels: list[str] = [],
464 avg_loss: AverageMeter = AverageMeter(),
465 avg_acc: AverageMeter = AverageMeter(),
466 avg_loss_val: AverageMeter = AverageMeter(),
467 avg_acc_val: AverageMeter = AverageMeter(),
464 callbacks: TrainingCallbacks = TrainingCallbacks(), 468 callbacks: TrainingCallbacks = TrainingCallbacks(),
465): 469):
466 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 470 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
472 global_step = 0 476 global_step = 0
473 cache = {} 477 cache = {}
474 478
475 avg_loss = AverageMeter() 479 best_acc = avg_acc.avg
476 avg_acc = AverageMeter() 480 best_acc_val = avg_acc_val.avg
477
478 avg_loss_val = AverageMeter()
479 avg_acc_val = AverageMeter()
480
481 best_acc = 0.0
482 best_acc_val = 0.0
483 481
484 local_progress_bar = tqdm( 482 local_progress_bar = tqdm(
485 range(num_training_steps_per_epoch + num_val_steps_per_epoch), 483 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@ def train_loop(
544 542
545 accelerator.backward(loss) 543 accelerator.backward(loss)
546 544
547 avg_loss.update(loss.detach_(), bsz) 545 avg_loss.update(loss.item(), bsz)
548 avg_acc.update(acc.detach_(), bsz) 546 avg_acc.update(acc.item(), bsz)
549 547
550 logs = { 548 logs = {
551 "train/loss": avg_loss.avg.item(), 549 "train/loss": avg_loss.avg,
552 "train/acc": avg_acc.avg.item(), 550 "train/acc": avg_acc.avg,
553 "train/cur_loss": loss.item(), 551 "train/cur_loss": loss.item(),
554 "train/cur_acc": acc.item(), 552 "train/cur_acc": acc.item(),
555 } 553 }
@@ -603,47 +601,47 @@ def train_loop(
603 loss = loss.detach_() 601 loss = loss.detach_()
604 acc = acc.detach_() 602 acc = acc.detach_()
605 603
606 cur_loss_val.update(loss, bsz) 604 cur_loss_val.update(loss.item(), bsz)
607 cur_acc_val.update(acc, bsz) 605 cur_acc_val.update(acc.item(), bsz)
608 606
609 avg_loss_val.update(loss, bsz) 607 avg_loss_val.update(loss.item(), bsz)
610 avg_acc_val.update(acc, bsz) 608 avg_acc_val.update(acc.item(), bsz)
611 609
612 local_progress_bar.update(1) 610 local_progress_bar.update(1)
613 global_progress_bar.update(1) 611 global_progress_bar.update(1)
614 612
615 logs = { 613 logs = {
616 "val/loss": avg_loss_val.avg.item(), 614 "val/loss": avg_loss_val.avg,
617 "val/acc": avg_acc_val.avg.item(), 615 "val/acc": avg_acc_val.avg,
618 "val/cur_loss": loss.item(), 616 "val/cur_loss": loss.item(),
619 "val/cur_acc": acc.item(), 617 "val/cur_acc": acc.item(),
620 } 618 }
621 local_progress_bar.set_postfix(**logs) 619 local_progress_bar.set_postfix(**logs)
622 620
623 logs["val/cur_loss"] = cur_loss_val.avg.item() 621 logs["val/cur_loss"] = cur_loss_val.avg
624 logs["val/cur_acc"] = cur_acc_val.avg.item() 622 logs["val/cur_acc"] = cur_acc_val.avg
625 623
626 accelerator.log(logs, step=global_step) 624 accelerator.log(logs, step=global_step)
627 625
628 if accelerator.is_main_process: 626 if accelerator.is_main_process:
629 if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: 627 if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
630 local_progress_bar.clear() 628 local_progress_bar.clear()
631 global_progress_bar.clear() 629 global_progress_bar.clear()
632 630
633 accelerator.print( 631 accelerator.print(
634 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 632 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
635 on_checkpoint(global_step, "milestone") 633 on_checkpoint(global_step, "milestone")
636 best_acc_val = avg_acc_val.avg.item() 634 best_acc_val = avg_acc_val.avg
637 else: 635 else:
638 if accelerator.is_main_process: 636 if accelerator.is_main_process:
639 if avg_acc.avg.item() > best_acc and milestone_checkpoints: 637 if avg_acc.avg > best_acc and milestone_checkpoints:
640 local_progress_bar.clear() 638 local_progress_bar.clear()
641 global_progress_bar.clear() 639 global_progress_bar.clear()
642 640
643 accelerator.print( 641 accelerator.print(
644 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") 642 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
645 on_checkpoint(global_step, "milestone") 643 on_checkpoint(global_step, "milestone")
646 best_acc = avg_acc.avg.item() 644 best_acc = avg_acc.avg
647 645
648 # Create the pipeline using using the trained modules and save it. 646 # Create the pipeline using using the trained modules and save it.
649 if accelerator.is_main_process: 647 if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
688 offset_noise_strength: float = 0.15, 686 offset_noise_strength: float = 0.15,
689 disc: Optional[ConvNeXtDiscriminator] = None, 687 disc: Optional[ConvNeXtDiscriminator] = None,
690 min_snr_gamma: int = 5, 688 min_snr_gamma: int = 5,
689 avg_loss: AverageMeter = AverageMeter(),
690 avg_acc: AverageMeter = AverageMeter(),
691 avg_loss_val: AverageMeter = AverageMeter(),
692 avg_acc_val: AverageMeter = AverageMeter(),
691 **kwargs, 693 **kwargs,
692): 694):
693 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( 695 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
737 num_epochs=num_train_epochs, 739 num_epochs=num_train_epochs,
738 gradient_accumulation_steps=gradient_accumulation_steps, 740 gradient_accumulation_steps=gradient_accumulation_steps,
739 group_labels=group_labels, 741 group_labels=group_labels,
742 avg_loss=avg_loss,
743 avg_acc=avg_acc,
744 avg_loss_val=avg_loss_val,
745 avg_acc_val=avg_acc_val,
740 callbacks=callbacks, 746 callbacks=callbacks,
741 ) 747 )
742 748