From 24f7a14defdbc050907ad282ebdc7ea6f6591363 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 13 Apr 2023 09:13:35 +0200
Subject: Added cycle LR decay

---
 training/functional.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'training/functional.py')

diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
@@ -673,6 +674,7 @@ def train(
    sample_frequency: int = 20,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
+   initial_samples: bool = True,
    global_step_offset: int = 0,
    guidance_scale: float = 0.0,
    prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
        sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        milestone_checkpoints=milestone_checkpoints,
+       initial_samples=initial_samples,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
--
cgit v1.2.3-70-g09d2
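
For context, a minimal sketch of the behavior this patch introduces, with simplified stand-ins for the repo's actual sampling and progress-bar code (the print call below is a hypothetical placeholder, not the project's API). With initial_samples=False the epoch-0 sample generation is skipped, while later epochs still sample every sample_frequency epochs:

    # Sketch of the gating logic added by this patch (assumed simplification).
    def train_loop(num_epochs: int = 100, sample_frequency: int = 10,
                   initial_samples: bool = True):
        for epoch in range(num_epochs):
            # Before the patch: samples were always generated at epoch 0.
            # After: epoch 0 only samples if initial_samples is True.
            if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                print(f"sampling at epoch {epoch}")  # stand-in for sample generation

    train_loop(num_epochs=25, sample_frequency=10, initial_samples=False)
    # -> samples at epochs 10 and 20, but not at epoch 0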