diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/training/functional.py b/training/functional.py index 4220c79..2dcfbb8 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -476,6 +476,9 @@ def train_loop( | |||
476 | except ImportError: | 476 | except ImportError: |
477 | pass | 477 | pass |
478 | 478 | ||
479 | num_training_steps += global_step_offset | ||
480 | global_step += global_step_offset | ||
481 | |||
479 | try: | 482 | try: |
480 | for epoch in range(num_epochs): | 483 | for epoch in range(num_epochs): |
481 | if accelerator.is_main_process: | 484 | if accelerator.is_main_process: |
@@ -484,13 +487,13 @@ def train_loop( | |||
484 | global_progress_bar.clear() | 487 | global_progress_bar.clear() |
485 | 488 | ||
486 | with on_eval(): | 489 | with on_eval(): |
487 | on_sample(global_step + global_step_offset) | 490 | on_sample(global_step) |
488 | 491 | ||
489 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 492 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
490 | local_progress_bar.clear() | 493 | local_progress_bar.clear() |
491 | global_progress_bar.clear() | 494 | global_progress_bar.clear() |
492 | 495 | ||
493 | on_checkpoint(global_step + global_step_offset, "training") | 496 | on_checkpoint(global_step, "training") |
494 | 497 | ||
495 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
496 | local_progress_bar.reset() | 499 | local_progress_bar.reset() |
@@ -592,7 +595,7 @@ def train_loop( | |||
592 | 595 | ||
593 | accelerator.print( | 596 | accelerator.print( |
594 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 597 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
595 | on_checkpoint(global_step + global_step_offset, "milestone") | 598 | on_checkpoint(global_step, "milestone") |
596 | best_acc_val = avg_acc_val.avg.item() | 599 | best_acc_val = avg_acc_val.avg.item() |
597 | else: | 600 | else: |
598 | if accelerator.is_main_process: | 601 | if accelerator.is_main_process: |
@@ -602,20 +605,20 @@ def train_loop( | |||
602 | 605 | ||
603 | accelerator.print( | 606 | accelerator.print( |
604 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") | 607 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") |
605 | on_checkpoint(global_step + global_step_offset, "milestone") | 608 | on_checkpoint(global_step, "milestone") |
606 | best_acc = avg_acc.avg.item() | 609 | best_acc = avg_acc.avg.item() |
607 | 610 | ||
608 | # Create the pipeline using the trained modules and save it. | 611 | # Create the pipeline using the trained modules and save it. |
609 | if accelerator.is_main_process: | 612 | if accelerator.is_main_process: |
610 | print("Finished!") | 613 | print("Finished!") |
611 | with on_eval(): | 614 | with on_eval(): |
612 | on_sample(global_step + global_step_offset) | 615 | on_sample(global_step) |
613 | on_checkpoint(global_step + global_step_offset, "end") | 616 | on_checkpoint(global_step, "end") |
614 | 617 | ||
615 | except KeyboardInterrupt: | 618 | except KeyboardInterrupt: |
616 | if accelerator.is_main_process: | 619 | if accelerator.is_main_process: |
617 | print("Interrupted") | 620 | print("Interrupted") |
618 | on_checkpoint(global_step + global_step_offset, "end") | 621 | on_checkpoint(global_step, "end") |
619 | raise KeyboardInterrupt | 622 | raise KeyboardInterrupt |
620 | 623 | ||
621 | 624 | ||
@@ -627,7 +630,6 @@ def train( | |||
627 | noise_scheduler: SchedulerMixin, | 630 | noise_scheduler: SchedulerMixin, |
628 | dtype: torch.dtype, | 631 | dtype: torch.dtype, |
629 | seed: int, | 632 | seed: int, |
630 | project: str, | ||
631 | train_dataloader: DataLoader, | 633 | train_dataloader: DataLoader, |
632 | val_dataloader: Optional[DataLoader], | 634 | val_dataloader: Optional[DataLoader], |
633 | optimizer: torch.optim.Optimizer, | 635 | optimizer: torch.optim.Optimizer, |
@@ -678,9 +680,6 @@ def train( | |||
678 | min_snr_gamma, | 680 | min_snr_gamma, |
679 | ) | 681 | ) |
680 | 682 | ||
681 | if accelerator.is_main_process: | ||
682 | accelerator.init_trackers(project) | ||
683 | |||
684 | train_loop( | 683 | train_loop( |
685 | accelerator=accelerator, | 684 | accelerator=accelerator, |
686 | optimizer=optimizer, | 685 | optimizer=optimizer, |
@@ -701,5 +700,4 @@ def train( | |||
701 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 700 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
702 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 701 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
703 | 702 | ||
704 | accelerator.end_training() | ||
705 | accelerator.free_memory() | 703 | accelerator.free_memory() |