From a7dc66ae0974886a6c6a4c50def1b733bc04525a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 11 Apr 2023 17:02:22 +0200
Subject: Update

---
 environment.yaml       |  2 +-
 train_lora.py          | 23 ++++++++++++++++-------
 train_ti.py            | 23 ++++++++++++++++-------
 training/functional.py | 22 ++++++++++------------
 4 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 418cb22..a95df2a 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  - xformers=0.0.18.dev498
+  - xformers=0.0.18.dev504
   - pip:
       - -e .
       - -e git+https://github.com/huggingface/diffusers#egg=diffusers
diff --git a/train_lora.py b/train_lora.py
index 0d8ee23..29e40b2 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -919,6 +919,8 @@ def main():
             args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
+    num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps)
+    num_train_steps = num_training_steps_per_epoch * num_train_epochs
 
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -956,15 +958,19 @@ def main():
 
     training_iter = 0
 
+    lora_project = "lora"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(lora_project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break
 
         print("")
-        print(f"============ LoRA cycle {training_iter} ============")
+        print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
         lora_optimizer = create_optimizer(params_to_optimize)
@@ -976,19 +982,18 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_project = f"lora_{training_iter}"
-        lora_checkpoint_output_dir = output_dir / lora_project / "model"
-        lora_sample_output_dir = output_dir / lora_project / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
+        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
 
         trainer(
             strategy=lora_strategy,
-            project=lora_project,
             train_dataloader=lora_datamodule.train_dataloader,
             val_dataloader=lora_datamodule.val_dataloader,
             optimizer=lora_optimizer,
             lr_scheduler=lora_lr_scheduler,
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
             checkpoint_output_dir=lora_checkpoint_output_dir,
             sample_frequency=lora_sample_frequency,
         )
 
+        training_iter += 1
+
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
                 args.num_train_steps / len(datamodule.train_dataset)
             ) * args.gradient_accumulation_steps
             sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
 
         if args.sample_num is not None:
            sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
        training_iter = 0
 
+        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(project)
+
         while True:
-            training_iter += 1
-            if training_iter > args.auto_cycles:
+            if training_iter >= args.auto_cycles:
                 response = input("Run another cycle? [y/n] ")
                 if response.lower().strip() == "n":
                     break
 
             print("")
-            print(f"------------ TI cycle {training_iter} ------------")
+            print(f"------------ TI cycle {training_iter + 1} ------------")
             print("")
 
             optimizer = create_optimizer(
@@ -908,17 +914,16 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-            sample_output_dir = output_dir / project / "samples"
-            checkpoint_output_dir = output_dir / project / "checkpoints"
+            sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+            checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
 
             trainer(
-                project=project,
                 train_dataloader=datamodule.train_dataloader,
                 val_dataloader=datamodule.val_dataloader,
                 optimizer=optimizer,
                 lr_scheduler=lr_scheduler,
                 num_train_epochs=num_train_epochs,
+                global_step_offset=training_iter * num_train_steps,
                 # --
                 group_labels=["emb"],
                 checkpoint_output_dir=checkpoint_output_dir,
                 sample_frequency=sample_frequency,
                 placeholder_tokens=args.placeholder_tokens,
                 placeholder_token_ids=placeholder_token_ids,
             )
 
+            training_iter += 1
+
+        accelerator.end_training()
+
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass
 
+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step + global_step_offset)
+                        on_sample(global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(
                         accelerator.print(
                             f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc_val = avg_acc_val.avg.item()
                 else:
                     if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(
                         accelerator.print(
                             f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
-                        on_checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step, "milestone")
                         best_acc = avg_acc.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
 
             with on_eval():
-                on_sample(global_step + global_step_offset)
-                on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+                on_checkpoint(global_step, "end")
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@
         min_snr_gamma,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    accelerator.end_training()
     accelerator.free_memory()
-- 
cgit v1.2.3-70-g09d2
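
A note on the step bookkeeping both training scripts now do up front: the number of optimizer steps per epoch is the dataset length divided by the gradient accumulation factor, rounded up, and num_train_steps is that figure times the number of epochs. A toy illustration with made-up numbers (in the scripts the inputs come from len(datamodule.train_dataset), args.gradient_accumulation_steps and the derived num_train_epochs):

import math

# Hypothetical values for illustration only.
dataset_len = 103
gradient_accumulation_steps = 4
num_train_epochs = 3

num_training_steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)
num_train_steps = num_training_steps_per_epoch * num_train_epochs

print(num_training_steps_per_epoch)  # 26
print(num_train_steps)               # 78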
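
The new global_step_offset argument is what keeps the step counter continuous across auto cycles: train_loop() now adds the offset to num_training_steps and global_step once, instead of adding it inside every callback. A minimal, self-contained sketch of that pattern; run_cycle and the log callback are illustrative stand-ins, not the repo's API:

def run_cycle(num_steps: int, global_step_offset: int, log) -> int:
    # Toy stand-in for one trainer() call: step numbers start at the offset.
    global_step = global_step_offset
    for _ in range(num_steps):
        global_step += 1
        log(global_step)  # each cycle continues where the previous one stopped
    return global_step


if __name__ == "__main__":
    logged = []
    num_train_steps = 5  # steps per cycle (toy value)
    for training_iter in range(3):  # three cycles, as with auto_cycles=3
        run_cycle(num_train_steps, training_iter * num_train_steps, logged.append)
    print(logged)  # [1, 2, ..., 15]: one monotonically increasing sequence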
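
With init_trackers() moved out of training/functional.py and end_training() out of train(), each script opens a single tracker run before its cycle loop and closes it once afterwards, so every cycle logs into the same run. A rough sketch of that lifecycle using Hugging Face Accelerate; the inner loop is a placeholder for the real trainer() call and the logged metric is made up:

from accelerate import Accelerator


def main() -> None:
    accelerator = Accelerator()  # pass log_with="tensorboard" or "wandb" to actually record logs
    num_train_steps = 10         # toy value; the scripts derive it as shown above
    auto_cycles = 2

    if accelerator.is_main_process:
        accelerator.init_trackers("lora")  # one run shared by every cycle

    for training_iter in range(auto_cycles):
        for step in range(num_train_steps):
            global_step = training_iter * num_train_steps + step  # offset keeps steps unique
            accelerator.log({"loss": 0.0}, step=global_step)      # placeholder metric

    accelerator.end_training()  # close the trackers once, after the last cycle


if __name__ == "__main__":
    main()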