From 71f4a40bb48be4f2759ba2d83faff39691cb2955 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 19:03:25 +0200
Subject: Improved automation caps

---
 training/functional.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 2da0f69..ebc40de 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -42,7 +42,7 @@ class TrainingCallbacks():
     on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
-    on_sample: Callable[[int], None] = const()
+    on_sample: Callable[[int, int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
@@ -96,6 +96,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
     num_steps: int = 20,
@@ -125,7 +126,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = output_dir / pool / f"step_{step}.jpg"
+        file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
@@ -455,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -518,12 +519,12 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
+                if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step)
+                        on_sample(cycle, global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
@@ -648,7 +649,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(global_step)
+                on_sample(cycle, global_step)
             on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
@@ -680,7 +681,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -731,7 +732,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
-        initial_samples=initial_samples,
+        cycle=cycle,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
-- 
cgit v1.2.3-54-g00ecf
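
Note: in summary, the patch replaces the boolean initial_samples flag with an
integer cycle, where cycle == 1 reproduces the old initial_samples=True
behavior (a sample pass at epoch 0), and threads cycle through on_sample so
that sample images from successive training cycles are written to
step_{cycle}_{step}.jpg and no longer overwrite one another. Below is a
minimal sketch of a callback matching the new signature; only the
(cycle, step) signature and the file-naming scheme come from the patch,
the body is purely illustrative.

    # Sketch of an on_sample callback matching the new signature,
    # Callable[[int, int], None]. The body here is hypothetical.
    def on_sample(cycle: int, step: int) -> None:
        # cycle == 1 corresponds to the old initial_samples=True behavior:
        # train_loop() samples at epoch 0 only during the first cycle.
        print(f"writing samples for cycle {cycle} at global step {step}")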