-rw-r--r--  environment.yaml       |  2 +-
-rw-r--r--  train_lora.py          | 23 ++++++++++++++++-------
-rw-r--r--  train_ti.py            | 23 ++++++++++++++++-------
-rw-r--r--  training/functional.py | 22 ++++++++++------------
4 files changed, 43 insertions(+), 27 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 418cb22..a95df2a 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  - xformers=0.0.18.dev498
+  - xformers=0.0.18.dev504
   - pip:
       - -e .
       - -e git+https://github.com/huggingface/diffusers#egg=diffusers
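
The only change to environment.yaml is the xformers pin moving from 0.0.18.dev498 to 0.0.18.dev504. As a quick sanity check after refreshing the environment (the check itself is my addition, not part of the commit), the installed build can be confirmed like this:

# Hypothetical check, run inside the updated conda environment, that the
# pinned xformers build is the one actually importable.
import xformers

print(xformers.__version__)  # expected to report 0.0.18.dev504 after this change
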
diff --git a/train_lora.py b/train_lora.py
index 0d8ee23..29e40b2 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -919,6 +919,8 @@ def main():
             args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)

@@ -956,15 +958,19 @@ def main():

     training_iter = 0

+    lora_project = "lora"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(lora_project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break

         print("")
-        print(f"============ LoRA cycle {training_iter} ============")
+        print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")

         lora_optimizer = create_optimizer(params_to_optimize)
@@ -976,19 +982,18 @@ def main():
             train_epochs=num_train_epochs,
         )

-        lora_project = f"lora_{training_iter}"
-        lora_checkpoint_output_dir = output_dir / lora_project / "model"
-        lora_sample_output_dir = output_dir / lora_project / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
+        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"

         trainer(
             strategy=lora_strategy,
-            project=lora_project,
             train_dataloader=lora_datamodule.train_dataloader,
             val_dataloader=lora_datamodule.val_dataloader,
             optimizer=lora_optimizer,
             lr_scheduler=lora_lr_scheduler,
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -996,6 +1001,10 @@ def main():
             sample_frequency=lora_sample_frequency,
         )

+        training_iter += 1
+
+    accelerator.end_training()
+


if __name__ == "__main__":
    main()
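
The pattern introduced in train_lora.py: the tracker run is opened once before the auto-cycle loop (accelerator.init_trackers("lora")) and closed once after it (accelerator.end_training()), while each cycle passes global_step_offset = training_iter * num_train_steps into the trainer so logged steps keep increasing across cycles instead of restarting at zero. A minimal stand-alone sketch of that pattern follows; it is not the repo's code, and the cycle count, per-cycle step count, and dummy loss value are placeholders:

from accelerate import Accelerator

accelerator = Accelerator()  # in the repo, the tracker backend (log_with, etc.) is configured elsewhere
if accelerator.is_main_process:
    accelerator.init_trackers("lora")        # one tracker run spanning every cycle

num_train_steps = 100                        # placeholder per-cycle step count
training_iter = 0
while training_iter < 3:                     # stands in for the auto_cycles prompt
    offset = training_iter * num_train_steps
    for step in range(num_train_steps):
        accelerator.log({"train/loss": 0.0}, step=offset + step)  # monotonic step axis
    training_iter += 1

accelerator.end_training()                   # closed once, after the last cycle
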
diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
             args.num_train_steps / len(datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
     if args.sample_num is not None:
         sample_frequency = math.ceil(num_train_epochs / args.sample_num)

     training_iter = 0

+    project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break

         print("")
-        print(f"------------ TI cycle {training_iter} ------------")
+        print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")

         optimizer = create_optimizer(
@@ -908,17 +914,16 @@ def main():
             mid_point=args.lr_mid_point,
         )

-        project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-        sample_output_dir = output_dir / project / "samples"
-        checkpoint_output_dir = output_dir / project / "checkpoints"
+        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"

         trainer(
-            project=project,
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -928,6 +933,10 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )

+        training_iter += 1
+
+    accelerator.end_training()
+
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
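
train_ti.py gets the same restructuring: a single tracker project per run (the lone placeholder token, or "ti" when several are trained), per-cycle output subdirectories in place of per-cycle project names, and a per-cycle global_step_offset. The resulting on-disk layout, sketched with assumed names (the output root and the token are illustrative, not taken from the repo):

from pathlib import Path

output_dir = Path("output")                  # assumed output root
project = "sks"                              # placeholder_tokens[0] when there is exactly one token
for training_iter in range(2):               # two completed auto-cycles
    print(output_dir / project / f"{training_iter + 1}" / "samples")
    print(output_dir / project / f"{training_iter + 1}" / "checkpoints")
# output/sks/1/samples
# output/sks/1/checkpoints
# output/sks/2/samples
# output/sks/2/checkpoints
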
diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass

+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                 global_progress_bar.clear()

                 with on_eval():
-                    on_sample(global_step + global_step_offset)
+                    on_sample(global_step)

             if epoch % checkpoint_frequency == 0 and epoch != 0:
                 local_progress_bar.clear()
                 global_progress_bar.clear()

-                on_checkpoint(global_step + global_step_offset, "training")
+                on_checkpoint(global_step, "training")

             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(

                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc_val = avg_acc_val.avg.item()
             else:
                 if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(

                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc = avg_acc.avg.item()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            with on_eval():
-                on_sample(global_step + global_step_offset)
-            on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+            on_checkpoint(global_step, "end")

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
        raise KeyboardInterrupt


@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@ def train(
         min_snr_gamma,
     )

-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)

-    accelerator.end_training()
     accelerator.free_memory()
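
In training/functional.py the offset is folded into global_step (and into the step budget) once at the top of train_loop, so every on_sample/on_checkpoint call site can pass the plain global_step, and tracker setup/teardown moves out of train() into the callers shown above. A toy sketch of that offset handling, with a made-up callback and step counts:

def train_loop_sketch(num_training_steps, global_step_offset, on_checkpoint):
    # Shift both the counter and the budget once, up front, so callbacks
    # always receive the already-offset step.
    global_step = global_step_offset
    num_training_steps += global_step_offset
    while global_step < num_training_steps:
        global_step += 1                      # one optimizer step
    on_checkpoint(global_step, "end")         # no "+ global_step_offset" needed here

train_loop_sketch(10, 100, lambda step, kind: print(step, kind))  # prints "110 end"
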