diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 58 |
1 file changed, 32 insertions, 26 deletions
diff --git a/training/functional.py b/training/functional.py index 695a24f..3036ed9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -461,6 +461,10 @@ def train_loop( | |||
461 | num_epochs: int = 100, | 461 | num_epochs: int = 100, |
462 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, |
463 | group_labels: list[str] = [], | 463 | group_labels: list[str] = [], |
464 | avg_loss: AverageMeter = AverageMeter(), | ||
465 | avg_acc: AverageMeter = AverageMeter(), | ||
466 | avg_loss_val: AverageMeter = AverageMeter(), | ||
467 | avg_acc_val: AverageMeter = AverageMeter(), | ||
464 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 468 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
465 | ): | 469 | ): |
466 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 470 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
@@ -472,14 +476,8 @@ def train_loop( | |||
472 | global_step = 0 | 476 | global_step = 0 |
473 | cache = {} | 477 | cache = {} |
474 | 478 | ||
475 | avg_loss = AverageMeter() | 479 | best_acc = avg_acc.avg |
476 | avg_acc = AverageMeter() | 480 | best_acc_val = avg_acc_val.avg |
477 | |||
478 | avg_loss_val = AverageMeter() | ||
479 | avg_acc_val = AverageMeter() | ||
480 | |||
481 | best_acc = 0.0 | ||
482 | best_acc_val = 0.0 | ||
483 | 481 | ||
484 | local_progress_bar = tqdm( | 482 | local_progress_bar = tqdm( |
485 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 483 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
@@ -544,12 +542,12 @@ def train_loop( | |||
544 | 542 | ||
545 | accelerator.backward(loss) | 543 | accelerator.backward(loss) |
546 | 544 | ||
547 | avg_loss.update(loss.detach_(), bsz) | 545 | avg_loss.update(loss.item(), bsz) |
548 | avg_acc.update(acc.detach_(), bsz) | 546 | avg_acc.update(acc.item(), bsz) |
549 | 547 | ||
550 | logs = { | 548 | logs = { |
551 | "train/loss": avg_loss.avg.item(), | 549 | "train/loss": avg_loss.avg, |
552 | "train/acc": avg_acc.avg.item(), | 550 | "train/acc": avg_acc.avg, |
553 | "train/cur_loss": loss.item(), | 551 | "train/cur_loss": loss.item(), |
554 | "train/cur_acc": acc.item(), | 552 | "train/cur_acc": acc.item(), |
555 | } | 553 | } |
@@ -603,47 +601,47 @@ def train_loop( | |||
603 | loss = loss.detach_() | 601 | loss = loss.detach_() |
604 | acc = acc.detach_() | 602 | acc = acc.detach_() |
605 | 603 | ||
606 | cur_loss_val.update(loss, bsz) | 604 | cur_loss_val.update(loss.item(), bsz) |
607 | cur_acc_val.update(acc, bsz) | 605 | cur_acc_val.update(acc.item(), bsz) |
608 | 606 | ||
609 | avg_loss_val.update(loss, bsz) | 607 | avg_loss_val.update(loss.item(), bsz) |
610 | avg_acc_val.update(acc, bsz) | 608 | avg_acc_val.update(acc.item(), bsz) |
611 | 609 | ||
612 | local_progress_bar.update(1) | 610 | local_progress_bar.update(1) |
613 | global_progress_bar.update(1) | 611 | global_progress_bar.update(1) |
614 | 612 | ||
615 | logs = { | 613 | logs = { |
616 | "val/loss": avg_loss_val.avg.item(), | 614 | "val/loss": avg_loss_val.avg, |
617 | "val/acc": avg_acc_val.avg.item(), | 615 | "val/acc": avg_acc_val.avg, |
618 | "val/cur_loss": loss.item(), | 616 | "val/cur_loss": loss.item(), |
619 | "val/cur_acc": acc.item(), | 617 | "val/cur_acc": acc.item(), |
620 | } | 618 | } |
621 | local_progress_bar.set_postfix(**logs) | 619 | local_progress_bar.set_postfix(**logs) |
622 | 620 | ||
623 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 621 | logs["val/cur_loss"] = cur_loss_val.avg |
624 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 622 | logs["val/cur_acc"] = cur_acc_val.avg |
625 | 623 | ||
626 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
627 | 625 | ||
628 | if accelerator.is_main_process: | 626 | if accelerator.is_main_process: |
629 | if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: | 627 | if avg_acc_val.avg > best_acc_val and milestone_checkpoints: |
630 | local_progress_bar.clear() | 628 | local_progress_bar.clear() |
631 | global_progress_bar.clear() | 629 | global_progress_bar.clear() |
632 | 630 | ||
633 | accelerator.print( | 631 | accelerator.print( |
634 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") |
635 | on_checkpoint(global_step, "milestone") | 633 | on_checkpoint(global_step, "milestone") |
636 | best_acc_val = avg_acc_val.avg.item() | 634 | best_acc_val = avg_acc_val.avg |
637 | else: | 635 | else: |
638 | if accelerator.is_main_process: | 636 | if accelerator.is_main_process: |
639 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: | 637 | if avg_acc.avg > best_acc and milestone_checkpoints: |
640 | local_progress_bar.clear() | 638 | local_progress_bar.clear() |
641 | global_progress_bar.clear() | 639 | global_progress_bar.clear() |
642 | 640 | ||
643 | accelerator.print( | 641 | accelerator.print( |
644 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") | 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") |
645 | on_checkpoint(global_step, "milestone") | 643 | on_checkpoint(global_step, "milestone") |
646 | best_acc = avg_acc.avg.item() | 644 | best_acc = avg_acc.avg |
647 | 645 | ||
648 | # Create the pipeline using the trained modules and save it. | 646 | # Create the pipeline using the trained modules and save it. |
649 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
@@ -688,6 +686,10 @@ def train( | |||
688 | offset_noise_strength: float = 0.15, | 686 | offset_noise_strength: float = 0.15, |
689 | disc: Optional[ConvNeXtDiscriminator] = None, | 687 | disc: Optional[ConvNeXtDiscriminator] = None, |
690 | min_snr_gamma: int = 5, | 688 | min_snr_gamma: int = 5, |
689 | avg_loss: AverageMeter = AverageMeter(), | ||
690 | avg_acc: AverageMeter = AverageMeter(), | ||
691 | avg_loss_val: AverageMeter = AverageMeter(), | ||
692 | avg_acc_val: AverageMeter = AverageMeter(), | ||
691 | **kwargs, | 693 | **kwargs, |
692 | ): | 694 | ): |
693 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 695 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
@@ -737,6 +739,10 @@ def train( | |||
737 | num_epochs=num_train_epochs, | 739 | num_epochs=num_train_epochs, |
738 | gradient_accumulation_steps=gradient_accumulation_steps, | 740 | gradient_accumulation_steps=gradient_accumulation_steps, |
739 | group_labels=group_labels, | 741 | group_labels=group_labels, |
742 | avg_loss=avg_loss, | ||
743 | avg_acc=avg_acc, | ||
744 | avg_loss_val=avg_loss_val, | ||
745 | avg_acc_val=avg_acc_val, | ||
740 | callbacks=callbacks, | 746 | callbacks=callbacks, |
741 | ) | 747 | ) |
742 | 748 | ||