Diffstat (limited to 'training')
 -rw-r--r-- training/functional.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass
 
+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step + global_step_offset)
+                        on_sample(global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(
 
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc_val = avg_acc_val.avg.item()
             else:
                 if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(
 
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc = avg_acc.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(global_step + global_step_offset)
-            on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+            on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
 
 
@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@ def train(
         min_snr_gamma,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    accelerator.end_training()
     accelerator.free_memory()
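The first half of the diff is a scoping refactor: rather than adding `global_step_offset` at every `on_sample`/`on_checkpoint` call site, the offset is folded into `global_step` and `num_training_steps` once, before the epoch loop, so every later consumer sees the already-offset step. A minimal sketch of the pattern; the loop and callback below are simplified stand-ins, not the real train_loop:

# Sketch of the refactor above, with stand-in names: the offset is
# folded into global_step/num_training_steps once, instead of being
# re-added at each on_sample/on_checkpoint call site.

def run(global_step_offset, num_training_steps, on_checkpoint):
    global_step = 0

    # Fold the resume offset in once, up front...
    num_training_steps += global_step_offset
    global_step += global_step_offset

    while global_step < num_training_steps:
        global_step += 1
        # ...so callbacks receive global_step directly, with no
        # repeated `global_step + global_step_offset` arithmetic.
        on_checkpoint(global_step, "training")

run(global_step_offset=100, num_training_steps=5,
    on_checkpoint=lambda step, kind: print(f"checkpoint @ {step} ({kind})"))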

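The remaining hunks move experiment-tracker ownership out of train(): the `project` parameter is dropped, and `accelerator.init_trackers(project)` / `accelerator.end_training()` no longer run inside it, which implies the caller now manages that lifecycle. A hedged sketch of what the caller side could look like under that assumption; the `log_with`/`project_dir` arguments and the project name are illustrative, and the train(...) invocation itself is elided:

# Hedged sketch: caller-side tracker lifecycle, assuming train() no
# longer calls init_trackers()/end_training() itself.
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="output")

if accelerator.is_main_process:
    accelerator.init_trackers("my-project")  # placeholder project name

try:
    ...  # train(accelerator=accelerator, ...) would go here (elided)
finally:
    # Teardown formerly done inside train(); now the caller's job.
    accelerator.end_training()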