Diffstat (limited to 'train_lora.py')
 -rw-r--r--  train_lora.py  12  +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 073e939..b8c7396 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -945,6 +945,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(lora_project)
 
+    lora_sample_output_dir = output_dir / lora_project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -995,8 +997,7 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
-        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
 
         trainer(
             strategy=lora_strategy,
@@ -1007,6 +1008,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -1015,11 +1017,11 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate_emb is not None:
+        if learning_rate_emb is not None:
             learning_rate_emb *= args.cycle_decay
-        if args.learning_rate_unet is not None:
+        if learning_rate_unet is not None:
             learning_rate_unet *= args.cycle_decay
-        if args.learning_rate_text is not None:
+        if learning_rate_text is not None:
             learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
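For context on the last hunk: after each training cycle the loop multiplies every active learning rate by args.cycle_decay, and the fix makes the "is not None" checks test the same local variables that are being updated rather than the original args.* values. A minimal standalone sketch of that per-cycle decay pattern, assuming the local rates are seeded from args above the diffed region (the Namespace values here are illustrative, not from the source):

from argparse import Namespace

# Hypothetical stand-in for the parsed CLI args used by train_lora.py;
# the attribute names match the diff, the values are illustrative.
args = Namespace(
    learning_rate_emb=1e-4,
    learning_rate_unet=None,   # a rate left unset stays None and is skipped
    learning_rate_text=5e-5,
    cycle_decay=0.5,
    auto_cycles=2,
)

# Local working copies, mutated between cycles (assumed to be seeded
# from args somewhere above the diffed region).
learning_rate_emb = args.learning_rate_emb
learning_rate_unet = args.learning_rate_unet
learning_rate_text = args.learning_rate_text

for training_iter in range(args.auto_cycles):
    # ... one training cycle runs here with the current rates ...
    # Decay the local rates, as the fixed hunk does: checking the local
    # variable instead of args.* means the test and the update refer to
    # the same value once the rates start drifting from their initial
    # settings.
    if learning_rate_emb is not None:
        learning_rate_emb *= args.cycle_decay
    if learning_rate_unet is not None:
        learning_rate_unet *= args.cycle_decay
    if learning_rate_text is not None:
        learning_rate_text *= args.cycle_decay

print(learning_rate_emb, learning_rate_unet, learning_rate_text)
# -> 2.5e-05 None 1.25e-05 after two cycles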
