From 24f7a14defdbc050907ad282ebdc7ea6f6591363 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 13 Apr 2023 09:13:35 +0200
Subject: Added cycle LR decay

---
 train_lora.py          | 12 +++++++-----
 train_ti.py            |  8 +++++---
 training/functional.py |  5 ++++-
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)

+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@
             train_epochs=num_train_epochs,
         )

-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"

         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@
         )

         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
             learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
             learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
             learning_rate_text *= args.cycle_decay

     accelerator.end_training()
diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)

+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@
             mid_point=args.lr_mid_point,
         )

-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"

         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@
         )

         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay

     accelerator.end_training()
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()

@@ -673,6 +674,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
+        initial_samples=initial_samples,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
--
cgit v1.2.3-70-g09d2
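
Note: the patch decays the local learning-rate variables rather than the values stored in args, so the decay compounds from one auto-cycle to the next instead of resetting, and the new initial_samples flag suppresses the epoch-0 sample pass on every cycle after the first. The snippet below is a minimal, self-contained sketch of that loop for illustration only; the run_cycle() helper and the SimpleNamespace stand-in for the parsed CLI arguments are assumptions made for this example and are not part of the repository's API.

# Illustrative sketch only: a minimal rendering of the cycle LR decay
# behaviour introduced above. run_cycle() and the SimpleNamespace args
# are placeholders, not the repository's actual trainer interface.
from types import SimpleNamespace


def run_cycle(cycle: int, lr: float, initial_samples: bool) -> None:
    # Stand-in for the trainer() call: just report what this cycle would use.
    print(f"cycle {cycle}: lr={lr:.2e}, initial_samples={initial_samples}")


args = SimpleNamespace(auto_cycles=3, cycle_decay=0.5, learning_rate=1e-4)

learning_rate = args.learning_rate
training_iter = 0

while training_iter < args.auto_cycles:
    # Samples at epoch 0 are only rendered on the very first cycle,
    # mirroring initial_samples=training_iter == 0 in the patch.
    run_cycle(training_iter + 1, learning_rate, initial_samples=training_iter == 0)

    training_iter += 1
    # Decay the local learning rate between cycles (not the value stored in
    # args), so each cycle starts from the previous cycle's decayed rate.
    if learning_rate is not None:
        learning_rate *= args.cycle_decay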