 train_lora.py          | 12 +++++++-----
 train_ti.py            |  8 +++++---
 training/functional.py |  5 ++++-
 3 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)
 
+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
 
         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
             learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
             learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
             learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
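Taken together, the train_lora.py changes pin the sample directory once before the cycle loop, flatten the per-cycle model directory, render epoch-0 samples only on the first cycle, and decay the local learning-rate variables instead of re-reading args each cycle (the old guards tested args.* but mutated the locals). A minimal standalone sketch of that pattern; run_cycle and the constants are hypothetical stand-ins for the script's trainer and CLI arguments, not the repo's API:

    from pathlib import Path

    def run_cycle(cycle: int, lr: float, sample_dir: Path, model_dir: Path,
                  initial_samples: bool) -> None:
        # Stand-in for trainer(...); a real cycle would train and sample here.
        print(f"cycle {cycle}: lr={lr:.2e}, samples -> {sample_dir}, "
              f"model -> {model_dir}, initial_samples={initial_samples}")

    output_dir = Path("output")
    project = "demo"
    sample_dir = output_dir / project / "samples"  # fixed across cycles (new layout)
    learning_rate = 1e-4
    cycle_decay = 0.9

    for training_iter in range(3):
        # Per-cycle artifacts use a flat "model_<n>" layout instead of "<n>/model".
        model_dir = output_dir / project / f"model_{training_iter + 1}"
        run_cycle(
            cycle=training_iter,
            lr=learning_rate,
            sample_dir=sample_dir,
            model_dir=model_dir,
            # Only the first cycle renders epoch-0 samples; later cycles would
            # just repeat the previous cycle's final output.
            initial_samples=training_iter == 0,
        )
        # Guard on the local (possibly already-decayed) value, not args.
        if learning_rate is not None:
            learning_rate *= cycle_decay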
diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay
 
     accelerator.end_training()
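train_ti.py gets the same reshuffle; the net effect on the on-disk layout is easiest to see side by side. A throwaway sketch of the before/after paths, where output_dir and the project name are placeholders and nothing is written to disk:

    from pathlib import Path

    output_dir, project = Path("output"), "my-token"

    # Old layout: everything nested under a per-cycle directory.
    old_samples     = output_dir / project / "1" / "samples"
    old_checkpoints = output_dir / project / "1" / "checkpoints"

    # New layout: one shared samples dir, flat per-cycle checkpoint dirs.
    new_samples     = output_dir / project / "samples"
    new_checkpoints = output_dir / project / "checkpoints_1"

    print(old_samples, "->", new_samples)
    print(old_checkpoints, "->", new_checkpoints)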
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -673,6 +674,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
+        initial_samples=initial_samples,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
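The new initial_samples flag threads from train() into train_loop(), where it gates the epoch-0 sample pass; without it, every cycle after the first would re-render samples identical to the previous cycle's final output. A minimal sketch of the predicate the commit adds, where should_sample is a hypothetical helper rather than a function in the repo:

    def should_sample(epoch: int, sample_frequency: int, initial_samples: bool) -> bool:
        # Same condition as the new train_loop gate: sample on the usual
        # schedule, but skip epoch 0 unless initial_samples is set (the
        # callers set it only on the first training cycle).
        return epoch % sample_frequency == 0 and (initial_samples or epoch != 0)

    assert should_sample(0, 10, initial_samples=True)        # first cycle samples at epoch 0
    assert not should_sample(0, 10, initial_samples=False)   # later cycles skip epoch 0
    assert should_sample(10, 10, initial_samples=False)      # regular schedule unaffected
    assert not should_sample(5, 10, initial_samples=True)    # off-schedule epochs never sample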