summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
451 sample_frequency: int = 10, 451 sample_frequency: int = 10,
452 checkpoint_frequency: int = 50, 452 checkpoint_frequency: int = 50,
453 milestone_checkpoints: bool = True, 453 milestone_checkpoints: bool = True,
454 initial_samples: bool = True,
454 global_step_offset: int = 0, 455 global_step_offset: int = 0,
455 num_epochs: int = 100, 456 num_epochs: int = 100,
456 gradient_accumulation_steps: int = 1, 457 gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
513 try: 514 try:
514 for epoch in range(num_epochs): 515 for epoch in range(num_epochs):
515 if accelerator.is_main_process: 516 if accelerator.is_main_process:
516 if epoch % sample_frequency == 0: 517 if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
517 local_progress_bar.clear() 518 local_progress_bar.clear()
518 global_progress_bar.clear() 519 global_progress_bar.clear()
519 520
@@ -673,6 +674,7 @@ def train(
673 sample_frequency: int = 20, 674 sample_frequency: int = 20,
674 checkpoint_frequency: int = 50, 675 checkpoint_frequency: int = 50,
675 milestone_checkpoints: bool = True, 676 milestone_checkpoints: bool = True,
677 initial_samples: bool = True,
676 global_step_offset: int = 0, 678 global_step_offset: int = 0,
677 guidance_scale: float = 0.0, 679 guidance_scale: float = 0.0,
678 prior_loss_weight: float = 1.0, 680 prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
723 sample_frequency=sample_frequency, 725 sample_frequency=sample_frequency,
724 checkpoint_frequency=checkpoint_frequency, 726 checkpoint_frequency=checkpoint_frequency,
725 milestone_checkpoints=milestone_checkpoints, 727 milestone_checkpoints=milestone_checkpoints,
728 initial_samples=initial_samples,
726 global_step_offset=global_step_offset, 729 global_step_offset=global_step_offset,
727 num_epochs=num_train_epochs, 730 num_epochs=num_train_epochs,
728 gradient_accumulation_steps=gradient_accumulation_steps, 731 gradient_accumulation_steps=gradient_accumulation_steps,