From a7dc66ae0974886a6c6a4c50def1b733bc04525a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 11 Apr 2023 17:02:22 +0200
Subject: Update

---
 training/functional.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass

+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                     global_progress_bar.clear()

                     with on_eval():
-                        on_sample(global_step + global_step_offset)
+                        on_sample(global_step)

                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

-                    on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step, "training")

             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(
                         accelerator.print(
                             f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")

-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc_val = avg_acc_val.avg.item()
             else:
                 if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(
                         accelerator.print(
                             f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")

-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc = avg_acc.avg.item()

         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")

             with on_eval():
-                on_sample(global_step + global_step_offset)
-            on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+            on_checkpoint(global_step, "end")

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
         raise KeyboardInterrupt

@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@ def train(
         min_snr_gamma,
     )

-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)

-    accelerator.end_training()
     accelerator.free_memory()
--
cgit v1.2.3-70-g09d2