Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)
 
+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
 
         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
             learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
             learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
             learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
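For orientation (not part of the commit): the first two hunks flatten the per-cycle output layout. With paths relative to output_dir / lora_project and N = training_iter + 1:

    before:  <N>/model  and  <N>/samples   (fresh directories every cycle)
    after:   model_<N>  and  samples       (one samples directory shared across cycles)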
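The last hunk switches the decay guards from args.learning_rate_* to the local learning_rate_* values, which are the ones actually multiplied each cycle. A minimal runnable sketch of that pattern, assuming the locals are seeded from args earlier in main(); the Args stand-in and its values below are illustrative, not code from train_lora.py:

    # Illustrative stand-in for the parsed CLI arguments (assumed, not from the repo).
    class Args:
        cycle_decay = 0.9
        auto_cycles = 3
        learning_rate_emb = 1e-4
        learning_rate_unet = 1e-4
        learning_rate_text = None  # a rate may be unset, hence the None-checks

    args = Args()

    # Local copies are what actually get decayed each cycle; after the patch the
    # None-checks test these locals rather than the never-modified args values.
    learning_rate_emb = args.learning_rate_emb
    learning_rate_unet = args.learning_rate_unet
    learning_rate_text = args.learning_rate_text

    for training_iter in range(args.auto_cycles):
        # ... one training cycle would run here ...
        if learning_rate_emb is not None:
            learning_rate_emb *= args.cycle_decay
        if learning_rate_unet is not None:
            learning_rate_unet *= args.cycle_decay
        if learning_rate_text is not None:
            learning_rate_text *= args.cycle_decay
        print(training_iter, learning_rate_emb, learning_rate_unet, learning_rate_text)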