diff options
-rw-r--r-- | environment.yaml | 2 | ||||
-rw-r--r-- | train_lora.py | 23 | ||||
-rw-r--r-- | train_ti.py | 23 | ||||
-rw-r--r-- | training/functional.py | 22 |
4 files changed, 43 insertions, 27 deletions
diff --git a/environment.yaml b/environment.yaml index 418cb22..a95df2a 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -11,7 +11,7 @@ dependencies: | |||
11 | - python=3.10.8 | 11 | - python=3.10.8 |
12 | - pytorch=2.0.0=*cuda11.8* | 12 | - pytorch=2.0.0=*cuda11.8* |
13 | - torchvision=0.15.0 | 13 | - torchvision=0.15.0 |
14 | - xformers=0.0.18.dev498 | 14 | - xformers=0.0.18.dev504 |
15 | - pip: | 15 | - pip: |
16 | - -e . | 16 | - -e . |
17 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 17 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
diff --git a/train_lora.py b/train_lora.py index 0d8ee23..29e40b2 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -919,6 +919,8 @@ def main(): | |||
919 | args.num_train_steps / len(lora_datamodule.train_dataset) | 919 | args.num_train_steps / len(lora_datamodule.train_dataset) |
920 | ) * args.gradient_accumulation_steps | 920 | ) * args.gradient_accumulation_steps |
921 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 921 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
922 | num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) | ||
923 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
922 | if args.sample_num is not None: | 924 | if args.sample_num is not None: |
923 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 925 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
924 | 926 | ||
@@ -956,15 +958,19 @@ def main(): | |||
956 | 958 | ||
957 | training_iter = 0 | 959 | training_iter = 0 |
958 | 960 | ||
961 | lora_project = "lora" | ||
962 | |||
963 | if accelerator.is_main_process: | ||
964 | accelerator.init_trackers(lora_project) | ||
965 | |||
959 | while True: | 966 | while True: |
960 | training_iter += 1 | 967 | if training_iter >= args.auto_cycles: |
961 | if training_iter > args.auto_cycles: | ||
962 | response = input("Run another cycle? [y/n] ") | 968 | response = input("Run another cycle? [y/n] ") |
963 | if response.lower().strip() == "n": | 969 | if response.lower().strip() == "n": |
964 | break | 970 | break |
965 | 971 | ||
966 | print("") | 972 | print("") |
967 | print(f"============ LoRA cycle {training_iter} ============") | 973 | print(f"============ LoRA cycle {training_iter + 1} ============") |
968 | print("") | 974 | print("") |
969 | 975 | ||
970 | lora_optimizer = create_optimizer(params_to_optimize) | 976 | lora_optimizer = create_optimizer(params_to_optimize) |
@@ -976,19 +982,18 @@ def main(): | |||
976 | train_epochs=num_train_epochs, | 982 | train_epochs=num_train_epochs, |
977 | ) | 983 | ) |
978 | 984 | ||
979 | lora_project = f"lora_{training_iter}" | 985 | lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model" |
980 | lora_checkpoint_output_dir = output_dir / lora_project / "model" | 986 | lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples" |
981 | lora_sample_output_dir = output_dir / lora_project / "samples" | ||
982 | 987 | ||
983 | trainer( | 988 | trainer( |
984 | strategy=lora_strategy, | 989 | strategy=lora_strategy, |
985 | project=lora_project, | ||
986 | train_dataloader=lora_datamodule.train_dataloader, | 990 | train_dataloader=lora_datamodule.train_dataloader, |
987 | val_dataloader=lora_datamodule.val_dataloader, | 991 | val_dataloader=lora_datamodule.val_dataloader, |
988 | optimizer=lora_optimizer, | 992 | optimizer=lora_optimizer, |
989 | lr_scheduler=lora_lr_scheduler, | 993 | lr_scheduler=lora_lr_scheduler, |
990 | num_train_epochs=num_train_epochs, | 994 | num_train_epochs=num_train_epochs, |
991 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 995 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
996 | global_step_offset=training_iter * num_train_steps, | ||
992 | # -- | 997 | # -- |
993 | group_labels=group_labels, | 998 | group_labels=group_labels, |
994 | sample_output_dir=lora_sample_output_dir, | 999 | sample_output_dir=lora_sample_output_dir, |
@@ -996,6 +1001,10 @@ def main(): | |||
996 | sample_frequency=lora_sample_frequency, | 1001 | sample_frequency=lora_sample_frequency, |
997 | ) | 1002 | ) |
998 | 1003 | ||
1004 | training_iter += 1 | ||
1005 | |||
1006 | accelerator.end_training() | ||
1007 | |||
999 | 1008 | ||
1000 | if __name__ == "__main__": | 1009 | if __name__ == "__main__": |
1001 | main() | 1010 | main() |
diff --git a/train_ti.py b/train_ti.py index 009495b..d7878cd 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -870,20 +870,26 @@ def main(): | |||
870 | args.num_train_steps / len(datamodule.train_dataset) | 870 | args.num_train_steps / len(datamodule.train_dataset) |
871 | ) * args.gradient_accumulation_steps | 871 | ) * args.gradient_accumulation_steps |
872 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 872 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
873 | num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) | ||
874 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
873 | if args.sample_num is not None: | 875 | if args.sample_num is not None: |
874 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 876 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
875 | 877 | ||
876 | training_iter = 0 | 878 | training_iter = 0 |
877 | 879 | ||
880 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | ||
881 | |||
882 | if accelerator.is_main_process: | ||
883 | accelerator.init_trackers(project) | ||
884 | |||
878 | while True: | 885 | while True: |
879 | training_iter += 1 | 886 | if training_iter >= args.auto_cycles: |
880 | if training_iter > args.auto_cycles: | ||
881 | response = input("Run another cycle? [y/n] ") | 887 | response = input("Run another cycle? [y/n] ") |
882 | if response.lower().strip() == "n": | 888 | if response.lower().strip() == "n": |
883 | break | 889 | break |
884 | 890 | ||
885 | print("") | 891 | print("") |
886 | print(f"------------ TI cycle {training_iter} ------------") | 892 | print(f"------------ TI cycle {training_iter + 1} ------------") |
887 | print("") | 893 | print("") |
888 | 894 | ||
889 | optimizer = create_optimizer( | 895 | optimizer = create_optimizer( |
@@ -908,17 +914,16 @@ def main(): | |||
908 | mid_point=args.lr_mid_point, | 914 | mid_point=args.lr_mid_point, |
909 | ) | 915 | ) |
910 | 916 | ||
911 | project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" | 917 | sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples" |
912 | sample_output_dir = output_dir / project / "samples" | 918 | checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints" |
913 | checkpoint_output_dir = output_dir / project / "checkpoints" | ||
914 | 919 | ||
915 | trainer( | 920 | trainer( |
916 | project=project, | ||
917 | train_dataloader=datamodule.train_dataloader, | 921 | train_dataloader=datamodule.train_dataloader, |
918 | val_dataloader=datamodule.val_dataloader, | 922 | val_dataloader=datamodule.val_dataloader, |
919 | optimizer=optimizer, | 923 | optimizer=optimizer, |
920 | lr_scheduler=lr_scheduler, | 924 | lr_scheduler=lr_scheduler, |
921 | num_train_epochs=num_train_epochs, | 925 | num_train_epochs=num_train_epochs, |
926 | global_step_offset=training_iter * num_train_steps, | ||
922 | # -- | 927 | # -- |
923 | group_labels=["emb"], | 928 | group_labels=["emb"], |
924 | checkpoint_output_dir=checkpoint_output_dir, | 929 | checkpoint_output_dir=checkpoint_output_dir, |
@@ -928,6 +933,10 @@ def main(): | |||
928 | placeholder_token_ids=placeholder_token_ids, | 933 | placeholder_token_ids=placeholder_token_ids, |
929 | ) | 934 | ) |
930 | 935 | ||
936 | training_iter += 1 | ||
937 | |||
938 | accelerator.end_training() | ||
939 | |||
931 | if not args.sequential: | 940 | if not args.sequential: |
932 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 941 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
933 | else: | 942 | else: |
diff --git a/training/functional.py b/training/functional.py index 4220c79..2dcfbb8 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -476,6 +476,9 @@ def train_loop( | |||
476 | except ImportError: | 476 | except ImportError: |
477 | pass | 477 | pass |
478 | 478 | ||
479 | num_training_steps += global_step_offset | ||
480 | global_step += global_step_offset | ||
481 | |||
479 | try: | 482 | try: |
480 | for epoch in range(num_epochs): | 483 | for epoch in range(num_epochs): |
481 | if accelerator.is_main_process: | 484 | if accelerator.is_main_process: |
@@ -484,13 +487,13 @@ def train_loop( | |||
484 | global_progress_bar.clear() | 487 | global_progress_bar.clear() |
485 | 488 | ||
486 | with on_eval(): | 489 | with on_eval(): |
487 | on_sample(global_step + global_step_offset) | 490 | on_sample(global_step) |
488 | 491 | ||
489 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 492 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
490 | local_progress_bar.clear() | 493 | local_progress_bar.clear() |
491 | global_progress_bar.clear() | 494 | global_progress_bar.clear() |
492 | 495 | ||
493 | on_checkpoint(global_step + global_step_offset, "training") | 496 | on_checkpoint(global_step, "training") |
494 | 497 | ||
495 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
496 | local_progress_bar.reset() | 499 | local_progress_bar.reset() |
@@ -592,7 +595,7 @@ def train_loop( | |||
592 | 595 | ||
593 | accelerator.print( | 596 | accelerator.print( |
594 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 597 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
595 | on_checkpoint(global_step + global_step_offset, "milestone") | 598 | on_checkpoint(global_step, "milestone") |
596 | best_acc_val = avg_acc_val.avg.item() | 599 | best_acc_val = avg_acc_val.avg.item() |
597 | else: | 600 | else: |
598 | if accelerator.is_main_process: | 601 | if accelerator.is_main_process: |
@@ -602,20 +605,20 @@ def train_loop( | |||
602 | 605 | ||
603 | accelerator.print( | 606 | accelerator.print( |
604 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") | 607 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") |
605 | on_checkpoint(global_step + global_step_offset, "milestone") | 608 | on_checkpoint(global_step, "milestone") |
606 | best_acc = avg_acc.avg.item() | 609 | best_acc = avg_acc.avg.item() |
607 | 610 | ||
608 | # Create the pipeline using using the trained modules and save it. | 611 | # Create the pipeline using using the trained modules and save it. |
609 | if accelerator.is_main_process: | 612 | if accelerator.is_main_process: |
610 | print("Finished!") | 613 | print("Finished!") |
611 | with on_eval(): | 614 | with on_eval(): |
612 | on_sample(global_step + global_step_offset) | 615 | on_sample(global_step) |
613 | on_checkpoint(global_step + global_step_offset, "end") | 616 | on_checkpoint(global_step, "end") |
614 | 617 | ||
615 | except KeyboardInterrupt: | 618 | except KeyboardInterrupt: |
616 | if accelerator.is_main_process: | 619 | if accelerator.is_main_process: |
617 | print("Interrupted") | 620 | print("Interrupted") |
618 | on_checkpoint(global_step + global_step_offset, "end") | 621 | on_checkpoint(global_step, "end") |
619 | raise KeyboardInterrupt | 622 | raise KeyboardInterrupt |
620 | 623 | ||
621 | 624 | ||
@@ -627,7 +630,6 @@ def train( | |||
627 | noise_scheduler: SchedulerMixin, | 630 | noise_scheduler: SchedulerMixin, |
628 | dtype: torch.dtype, | 631 | dtype: torch.dtype, |
629 | seed: int, | 632 | seed: int, |
630 | project: str, | ||
631 | train_dataloader: DataLoader, | 633 | train_dataloader: DataLoader, |
632 | val_dataloader: Optional[DataLoader], | 634 | val_dataloader: Optional[DataLoader], |
633 | optimizer: torch.optim.Optimizer, | 635 | optimizer: torch.optim.Optimizer, |
@@ -678,9 +680,6 @@ def train( | |||
678 | min_snr_gamma, | 680 | min_snr_gamma, |
679 | ) | 681 | ) |
680 | 682 | ||
681 | if accelerator.is_main_process: | ||
682 | accelerator.init_trackers(project) | ||
683 | |||
684 | train_loop( | 683 | train_loop( |
685 | accelerator=accelerator, | 684 | accelerator=accelerator, |
686 | optimizer=optimizer, | 685 | optimizer=optimizer, |
@@ -701,5 +700,4 @@ def train( | |||
701 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 700 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
702 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 701 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
703 | 702 | ||
704 | accelerator.end_training() | ||
705 | accelerator.free_memory() | 703 | accelerator.free_memory() |