Diffstat (limited to 'training')
 training/functional.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass
 
+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step + global_step_offset)
+                        on_sample(global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(
 
                         accelerator.print(
                             f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc_val = avg_acc_val.avg.item()
             else:
                 if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(
 
                         accelerator.print(
                             f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc = avg_acc.avg.item()
 
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(global_step + global_step_offset)
-            on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+            on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
 
 
@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@ def train(
         min_snr_gamma,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    accelerator.end_training()
     accelerator.free_memory()
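
The pattern behind the `train_loop` hunks: instead of adding `global_step_offset` at every `on_sample`/`on_checkpoint` call site, the offset is folded into `global_step` (and `num_training_steps`) once before the loop, so every consumer sees the absolute step without per-call arithmetic. A minimal, runnable sketch of that equivalence (simplified stand-in loops, not the repo's actual signatures):

```python
# Sketch: folding a resume offset into the step counter once is
# equivalent to adding the offset at every reporting call site.

def loop_offset_per_call(steps: int, offset: int) -> list[int]:
    global_step = 0
    reported = []
    for _ in range(steps):
        global_step += 1
        reported.append(global_step + offset)  # old style: offset at each call
    return reported

def loop_offset_folded(steps: int, offset: int) -> list[int]:
    global_step = 0
    global_step += offset  # new style: fold the offset in once, up front
    reported = []
    for _ in range(steps):
        global_step += 1
        reported.append(global_step)
    return reported

# Both styles report identical absolute steps.
assert loop_offset_per_call(5, 100) == loop_offset_folded(5, 100)
```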
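
The deleted `init_trackers`/`end_training` pair is the standard Hugging Face Accelerate experiment-tracker lifecycle; this commit removes it from `train()` along with the `project` parameter, presumably moving ownership to the caller (an assumption; the diff alone does not show the new owner). For reference, that lifecycle looks roughly like this (project name and logged values are placeholders):

```python
from accelerate import Accelerator

# Sketch of the Accelerate tracker lifecycle removed from train().
accelerator = Accelerator(log_with="tensorboard", project_dir="output")

accelerator.init_trackers("example-project")   # start the tracker run
accelerator.log({"train/loss": 0.25}, step=1)  # log metrics during training
accelerator.end_training()                     # flush and close all trackers
accelerator.free_memory()
```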