 train_lora.py          | 12 +++++++-----
 train_ti.py            |  8 +++++---
 training/functional.py |  5 ++++-
 3 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)
 
+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
 
         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
             learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
             learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
             learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
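
A plausible reading of the last hunk: the learning_rate_* locals are what the decay multiplies (and what the loop rewrites each cycle), so the None guard should test them rather than the args values they were copied from. A minimal sketch under that assumption:

    # The local is what actually gets multiplied each cycle, so guard on it,
    # not on the CLI argument it was originally copied from.
    learning_rate_emb = args.learning_rate_emb
    if learning_rate_emb is not None:
        learning_rate_emb *= args.cycle_decay
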
diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay
 
     accelerator.end_training()
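
For orientation, a sketch of the path change in train_ti.py with hypothetical concrete values; per-cycle subtrees give way to suffixed sibling directories plus one samples directory shared across cycles:

    from pathlib import Path

    output_dir, project, training_iter = Path("output"), "demo", 0
    # before: output/demo/1/checkpoints and output/demo/1/samples
    # after:  output/demo/checkpoints_1 and a shared output/demo/samples
    checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
    sample_output_dir = output_dir / project / "samples"
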
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -673,6 +674,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
+        initial_samples=initial_samples,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
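
Both callers pass initial_samples=training_iter == 0, so the epoch-0 sample pass runs only on the first cycle; later cycles resume from the previous cycle's final state, where an epoch-0 sample would just duplicate the last one produced. A minimal sketch of the gate (generate_samples is a hypothetical stand-in for the actual sampling callback):

    for epoch in range(num_epochs):
        if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
            generate_samples(epoch)  # skipped at epoch 0 unless initial_samples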