| author | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
| commit | a7dc66ae0974886a6c6a4c50def1b733bc04525a (patch) | |
| tree | bbea49b82f8f87b0ce6141114875d6253c75d8ab | |
| parent | Randomize dataset across cycles (diff) | |
Update
| -rw-r--r-- | environment.yaml       |  2 |
| -rw-r--r-- | train_lora.py          | 23 |
| -rw-r--r-- | train_ti.py            | 23 |
| -rw-r--r-- | training/functional.py | 22 |
4 files changed, 43 insertions, 27 deletions
diff --git a/environment.yaml b/environment.yaml
index 418cb22..a95df2a 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  - xformers=0.0.18.dev498
+  - xformers=0.0.18.dev504
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
diff --git a/train_lora.py b/train_lora.py
index 0d8ee23..29e40b2 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -919,6 +919,8 @@ def main():
             args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
@@ -956,15 +958,19 @@ def main():
 
     training_iter = 0
 
+    lora_project = "lora"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(lora_project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break
 
         print("")
-        print(f"============ LoRA cycle {training_iter} ============")
+        print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
         lora_optimizer = create_optimizer(params_to_optimize)
@@ -976,19 +982,18 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_project = f"lora_{training_iter}"
-        lora_checkpoint_output_dir = output_dir / lora_project / "model"
-        lora_sample_output_dir = output_dir / lora_project / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
+        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
 
         trainer(
             strategy=lora_strategy,
-            project=lora_project,
             train_dataloader=lora_datamodule.train_dataloader,
             val_dataloader=lora_datamodule.val_dataloader,
             optimizer=lora_optimizer,
             lr_scheduler=lora_lr_scheduler,
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -996,6 +1001,10 @@ def main():
             sample_frequency=lora_sample_frequency,
         )
 
+        training_iter += 1
+
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
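For orientation, the cycle loop that train_lora.py ends up with after this change can be summarized as: one tracker run named "lora" for the whole session, a zero-based `training_iter` incremented at the end of each cycle, per-cycle output directories under the project folder, and a `global_step_offset` so step numbers keep growing across cycles. The sketch below is a condensed illustration, not the repository's verbatim code; `trainer`, `create_optimizer`, `create_lr_scheduler`, `args`, `accelerator`, `lora_datamodule`, `params_to_optimize`, and `output_dir` stand in for the script's real objects.

```python
import math

# Condensed sketch of the post-commit cycle loop in train_lora.py.
# Only the control flow shown here mirrors the diff; the helper objects
# passed in are assumptions standing in for the script's real wiring.
def run_lora_cycles(args, accelerator, lora_datamodule, params_to_optimize,
                    create_optimizer, create_lr_scheduler, trainer, output_dir,
                    num_train_epochs):
    num_training_steps_per_epoch = math.ceil(
        len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps
    )
    num_train_steps = num_training_steps_per_epoch * num_train_epochs

    lora_project = "lora"
    if accelerator.is_main_process:
        accelerator.init_trackers(lora_project)  # one tracker run for all cycles

    training_iter = 0
    while True:
        # Only prompt once the configured number of automatic cycles is done.
        if training_iter >= args.auto_cycles:
            if input("Run another cycle? [y/n] ").lower().strip() == "n":
                break

        lora_optimizer = create_optimizer(params_to_optimize)
        lora_lr_scheduler = create_lr_scheduler(lora_optimizer, train_epochs=num_train_epochs)

        trainer(
            optimizer=lora_optimizer,
            lr_scheduler=lora_lr_scheduler,
            num_train_epochs=num_train_epochs,
            # Shift logged steps so cycle N continues where cycle N-1 stopped.
            global_step_offset=training_iter * num_train_steps,
            sample_output_dir=output_dir / lora_project / f"{training_iter + 1}" / "samples",
            checkpoint_output_dir=output_dir / lora_project / f"{training_iter + 1}" / "model",
        )

        training_iter += 1

    accelerator.end_training()
```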
diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
                 args.num_train_steps / len(datamodule.train_dataset)
             ) * args.gradient_accumulation_steps
             sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+            num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+            num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
         training_iter = 0
 
+        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(project)
+
         while True:
-            training_iter += 1
-            if training_iter > args.auto_cycles:
+            if training_iter >= args.auto_cycles:
                 response = input("Run another cycle? [y/n] ")
                 if response.lower().strip() == "n":
                     break
 
             print("")
-            print(f"------------ TI cycle {training_iter} ------------")
+            print(f"------------ TI cycle {training_iter + 1} ------------")
             print("")
 
             optimizer = create_optimizer(
@@ -908,17 +914,16 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-            sample_output_dir = output_dir / project / "samples"
-            checkpoint_output_dir = output_dir / project / "checkpoints"
+            sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+            checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
 
             trainer(
-                project=project,
                 train_dataloader=datamodule.train_dataloader,
                 val_dataloader=datamodule.val_dataloader,
                 optimizer=optimizer,
                 lr_scheduler=lr_scheduler,
                 num_train_epochs=num_train_epochs,
+                global_step_offset=training_iter * num_train_steps,
                 # --
                 group_labels=["emb"],
                 checkpoint_output_dir=checkpoint_output_dir,
@@ -928,6 +933,10 @@ def main():
                 placeholder_token_ids=placeholder_token_ids,
             )
 
+            training_iter += 1
+
+        accelerator.end_training()
+
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
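The train_ti.py changes mirror the LoRA script: a single project name, per-cycle subdirectories, and a step offset of `training_iter * num_train_steps` per cycle. The effect of that offset is easiest to see with toy numbers; the figures below are purely illustrative and not taken from the repository.

```python
import math

# Purely illustrative numbers, not taken from the repository.
dataset_len = 400
gradient_accumulation_steps = 4
num_train_epochs = 10

num_training_steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)  # 100
num_train_steps = num_training_steps_per_epoch * num_train_epochs                    # 1000

# Each cycle hands the trainer an offset of training_iter * num_train_steps,
# so logged steps form one continuous axis instead of restarting at 0.
for training_iter in range(3):
    offset = training_iter * num_train_steps
    print(f"TI cycle {training_iter + 1}: steps {offset}..{offset + num_train_steps}")
# TI cycle 1: steps 0..1000
# TI cycle 2: steps 1000..2000
# TI cycle 3: steps 2000..3000
```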
diff --git a/training/functional.py b/training/functional.py
index 4220c79..2dcfbb8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -476,6 +476,9 @@ def train_loop(
     except ImportError:
         pass
 
+    num_training_steps += global_step_offset
+    global_step += global_step_offset
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -484,13 +487,13 @@ def train_loop(
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step + global_step_offset)
+                        on_sample(global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -592,7 +595,7 @@ def train_loop(
 
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc_val = avg_acc_val.avg.item()
             else:
                 if accelerator.is_main_process:
@@ -602,20 +605,20 @@ def train_loop(
 
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step, "milestone")
                     best_acc = avg_acc.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(global_step + global_step_offset)
-            on_checkpoint(global_step + global_step_offset, "end")
+                on_sample(global_step)
+            on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
 
 
@@ -627,7 +630,6 @@ def train(
     noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
-    project: str,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
@@ -678,9 +680,6 @@ def train(
         min_snr_gamma,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers(project)
-
    train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
@@ -701,5 +700,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    accelerator.end_training()
     accelerator.free_memory()
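In training/functional.py the offset is now folded into `global_step` (and the step budget) once at the top of `train_loop`, so the sampling and checkpoint callbacks simply receive `global_step`, and tracker setup and teardown move out of `train()` into the calling scripts. Below is a minimal, self-contained sketch of that pattern; the callback names mirror the file, everything else is simplified stand-in code.

```python
# Minimal sketch of the offset handling in train_loop after this commit;
# `on_sample` and `on_checkpoint` mirror the callback names in
# training/functional.py, the rest is simplified stand-in code.
def train_loop(num_epochs, steps_per_epoch, global_step_offset, on_sample, on_checkpoint):
    num_training_steps = num_epochs * steps_per_epoch

    # Fold the offset in once, instead of adding it at every callback site.
    num_training_steps += global_step_offset
    global_step = 0 + global_step_offset

    for epoch in range(num_epochs):
        for _ in range(steps_per_epoch):
            global_step += 1  # stands in for one optimizer step
        on_sample(global_step)

    on_checkpoint(global_step, "end")


# Second cycle of 2 epochs x 500 steps, started with an offset of 1000:
# samples land at steps 1500 and 2000, the final checkpoint at 2000.
train_loop(2, 500, 1000,
           on_sample=lambda step: print("sample at", step),
           on_checkpoint=lambda step, kind: print(kind, "checkpoint at", step))
```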
