 train_lora.py          | 12 +++++++-----
 train_ti.py            |  8 +++++---
 training/functional.py |  5 ++++-
 3 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)
 
+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
 
         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
            learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
            learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
            learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
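With this change the LoRA sample directory is created once, before the cycle loop, while each cycle still gets its own checkpoint directory under a flat model_{n} name instead of a nested {n}/model path. For illustration only, a minimal sketch of the resulting layout; the output_dir value and project name below are invented:

from pathlib import Path

output_dir = Path("output")   # assumed value, for illustration only
lora_project = "my-lora"      # assumed value, for illustration only

# Samples from every cycle now land in one shared directory ...
lora_sample_output_dir = output_dir / lora_project / "samples"

# ... while each cycle's checkpoints get their own flat directory.
for training_iter in range(3):
    lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
    print(lora_checkpoint_output_dir)
# output/my-lora/model_1
# output/my-lora/model_2
# output/my-lora/model_3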
diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay
 
     accelerator.end_training()
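The end-of-cycle decay is now guarded by the local learning_rate value itself, which may be None, rather than by the original CLI argument. A minimal sketch of that per-cycle decay with invented numbers (a starting rate of 1e-4 and a decay factor standing in for args.cycle_decay):

learning_rate = 1e-4   # assumed starting value; may be None when the rate is disabled
cycle_decay = 0.9      # assumed stand-in for args.cycle_decay

for training_iter in range(3):
    # ... one training cycle would run here ...
    if learning_rate is not None:
        learning_rate *= cycle_decay
    print(f"after cycle {training_iter + 1}: lr = {learning_rate:.2e}")
# after cycle 1: lr = 9.00e-05
# after cycle 2: lr = 8.10e-05
# after cycle 3: lr = 7.29e-05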
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -673,6 +674,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
+        initial_samples=initial_samples,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
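The new initial_samples flag only affects whether the epoch-0 sample pass runs: both callers pass training_iter == 0, so samples are still generated at the very start of the first cycle but skipped at the start of every later cycle. A small self-contained sketch of just that gate; the helper name below is invented, since the real check is inlined in train_loop:

def should_sample(epoch: int, sample_frequency: int, initial_samples: bool) -> bool:
    # Mirrors the condition added in train_loop.
    return epoch % sample_frequency == 0 and (initial_samples or epoch != 0)

# First cycle (initial_samples=True): samples at epochs 0, 10, 20, ...
print([e for e in range(30) if should_sample(e, 10, True)])   # [0, 10, 20]
# Continuation cycles (initial_samples=False): epoch 0 is skipped.
print([e for e in range(30) if should_sample(e, 10, False)])  # [10, 20]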
