summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
945 if accelerator.is_main_process: 945 if accelerator.is_main_process:
946 accelerator.init_trackers(lora_project) 946 accelerator.init_trackers(lora_project)
947 947
948 lora_sample_output_dir = output_dir / lora_project / "samples"
949
948 while True: 950 while True:
949 if training_iter >= args.auto_cycles: 951 if training_iter >= args.auto_cycles:
950 response = input("Run another cycle? [y/n] ") 952 response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
995 train_epochs=num_train_epochs, 997 train_epochs=num_train_epochs,
996 ) 998 )
997 999
998 lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model" 1000 lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
999 lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
1000 1001
1001 trainer( 1002 trainer(
1002 strategy=lora_strategy, 1003 strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
1007 num_train_epochs=num_train_epochs, 1008 num_train_epochs=num_train_epochs,
1008 gradient_accumulation_steps=args.gradient_accumulation_steps, 1009 gradient_accumulation_steps=args.gradient_accumulation_steps,
1009 global_step_offset=training_iter * num_train_steps, 1010 global_step_offset=training_iter * num_train_steps,
1011 initial_samples=training_iter == 0,
1010 # -- 1012 # --
1011 group_labels=group_labels, 1013 group_labels=group_labels,
1012 sample_output_dir=lora_sample_output_dir, 1014 sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
1015 ) 1017 )
1016 1018
1017 training_iter += 1 1019 training_iter += 1
1018 if args.learning_rate_emb is not None: 1020 if learning_rate_emb is not None:
1019 learning_rate_emb *= args.cycle_decay 1021 learning_rate_emb *= args.cycle_decay
1020 if args.learning_rate_unet is not None: 1022 if learning_rate_unet is not None:
1021 learning_rate_unet *= args.cycle_decay 1023 learning_rate_unet *= args.cycle_decay
1022 if args.learning_rate_text is not None: 1024 if learning_rate_text is not None:
1023 learning_rate_text *= args.cycle_decay 1025 learning_rate_text *= args.cycle_decay
1024 1026
1025 accelerator.end_training() 1027 accelerator.end_training()