diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py index ed8ae3a..54bbe78 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -451,6 +451,7 @@ def train_loop( | |||
451 | sample_frequency: int = 10, | 451 | sample_frequency: int = 10, |
452 | checkpoint_frequency: int = 50, | 452 | checkpoint_frequency: int = 50, |
453 | milestone_checkpoints: bool = True, | 453 | milestone_checkpoints: bool = True, |
454 | initial_samples: bool = True, | ||
454 | global_step_offset: int = 0, | 455 | global_step_offset: int = 0, |
455 | num_epochs: int = 100, | 456 | num_epochs: int = 100, |
456 | gradient_accumulation_steps: int = 1, | 457 | gradient_accumulation_steps: int = 1, |
@@ -513,7 +514,7 @@ def train_loop( | |||
513 | try: | 514 | try: |
514 | for epoch in range(num_epochs): | 515 | for epoch in range(num_epochs): |
515 | if accelerator.is_main_process: | 516 | if accelerator.is_main_process: |
516 | if epoch % sample_frequency == 0: | 517 | if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): |
517 | local_progress_bar.clear() | 518 | local_progress_bar.clear() |
518 | global_progress_bar.clear() | 519 | global_progress_bar.clear() |
519 | 520 | ||
@@ -673,6 +674,7 @@ def train( | |||
673 | sample_frequency: int = 20, | 674 | sample_frequency: int = 20, |
674 | checkpoint_frequency: int = 50, | 675 | checkpoint_frequency: int = 50, |
675 | milestone_checkpoints: bool = True, | 676 | milestone_checkpoints: bool = True, |
677 | initial_samples: bool = True, | ||
676 | global_step_offset: int = 0, | 678 | global_step_offset: int = 0, |
677 | guidance_scale: float = 0.0, | 679 | guidance_scale: float = 0.0, |
678 | prior_loss_weight: float = 1.0, | 680 | prior_loss_weight: float = 1.0, |
@@ -723,6 +725,7 @@ def train( | |||
723 | sample_frequency=sample_frequency, | 725 | sample_frequency=sample_frequency, |
724 | checkpoint_frequency=checkpoint_frequency, | 726 | checkpoint_frequency=checkpoint_frequency, |
725 | milestone_checkpoints=milestone_checkpoints, | 727 | milestone_checkpoints=milestone_checkpoints, |
728 | initial_samples=initial_samples, | ||
726 | global_step_offset=global_step_offset, | 729 | global_step_offset=global_step_offset, |
727 | num_epochs=num_train_epochs, | 730 | num_epochs=num_train_epochs, |
728 | gradient_accumulation_steps=gradient_accumulation_steps, | 731 | gradient_accumulation_steps=gradient_accumulation_steps, |