diff options
| -rw-r--r-- | train_lora.py | 66 | ||||
| -rw-r--r-- | train_ti.py | 68 | ||||
| -rw-r--r-- | training/functional.py | 4 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
| -rw-r--r-- | training/strategy/lora.py | 2 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
6 files changed, 89 insertions, 55 deletions
diff --git a/train_lora.py b/train_lora.py index e81742a..4bbc64e 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -199,6 +199,11 @@ def parse_args(): | |||
| 199 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 199 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 200 | ) | 200 | ) |
| 201 | parser.add_argument( | 201 | parser.add_argument( |
| 202 | "--train_dir_embeddings", | ||
| 203 | action="store_true", | ||
| 204 | help="Train embeddings loaded from embeddings directory.", | ||
| 205 | ) | ||
| 206 | parser.add_argument( | ||
| 202 | "--collection", | 207 | "--collection", |
| 203 | type=str, | 208 | type=str, |
| 204 | nargs='*', | 209 | nargs='*', |
| @@ -440,6 +445,12 @@ def parse_args(): | |||
| 440 | help="How often to save a checkpoint and sample image", | 445 | help="How often to save a checkpoint and sample image", |
| 441 | ) | 446 | ) |
| 442 | parser.add_argument( | 447 | parser.add_argument( |
| 448 | "--sample_num", | ||
| 449 | type=int, | ||
| 450 | default=None, | ||
| 451 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
| 452 | ) | ||
| 453 | parser.add_argument( | ||
| 443 | "--sample_image_size", | 454 | "--sample_image_size", |
| 444 | type=int, | 455 | type=int, |
| 445 | default=768, | 456 | default=768, |
| @@ -681,27 +692,36 @@ def main(): | |||
| 681 | embeddings.persist() | 692 | embeddings.persist() |
| 682 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 693 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") |
| 683 | 694 | ||
| 695 | placeholder_token_ids = [] | ||
| 696 | |||
| 684 | if args.embeddings_dir is not None: | 697 | if args.embeddings_dir is not None: |
| 685 | embeddings_dir = Path(args.embeddings_dir) | 698 | embeddings_dir = Path(args.embeddings_dir) |
| 686 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 699 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 687 | raise ValueError("--embeddings_dir must point to an existing directory") | 700 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 688 | 701 | ||
| 689 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 702 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 690 | embeddings.persist() | ||
| 691 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 703 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 692 | 704 | ||
| 693 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 705 | if args.train_dir_embeddings: |
| 694 | tokenizer=tokenizer, | 706 | args.placeholder_tokens = added_tokens |
| 695 | embeddings=embeddings, | 707 | placeholder_token_ids = added_ids |
| 696 | placeholder_tokens=args.placeholder_tokens, | 708 | print("Training embeddings from embeddings dir") |
| 697 | initializer_tokens=args.initializer_tokens, | 709 | else: |
| 698 | num_vectors=args.num_vectors, | 710 | embeddings.persist() |
| 699 | initializer_noise=args.initializer_noise, | 711 | |
| 700 | ) | 712 | if not args.train_dir_embeddings: |
| 701 | stats = list(zip( | 713 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 702 | args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 714 | tokenizer=tokenizer, |
| 703 | )) | 715 | embeddings=embeddings, |
| 704 | print(f"Training embeddings: {stats}") | 716 | placeholder_tokens=args.placeholder_tokens, |
| 717 | initializer_tokens=args.initializer_tokens, | ||
| 718 | num_vectors=args.num_vectors, | ||
| 719 | initializer_noise=args.initializer_noise, | ||
| 720 | ) | ||
| 721 | stats = list(zip( | ||
| 722 | args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | ||
| 723 | )) | ||
| 724 | print(f"Training embeddings: {stats}") | ||
| 705 | 725 | ||
| 706 | if args.scale_lr: | 726 | if args.scale_lr: |
| 707 | args.learning_rate_unet = ( | 727 | args.learning_rate_unet = ( |
| @@ -897,6 +917,8 @@ def main(): | |||
| 897 | args.num_train_steps / len(lora_datamodule.train_dataset) | 917 | args.num_train_steps / len(lora_datamodule.train_dataset) |
| 898 | ) * args.gradient_accumulation_steps | 918 | ) * args.gradient_accumulation_steps |
| 899 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 919 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
| 920 | if args.sample_num is not None: | ||
| 921 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
| 900 | 922 | ||
| 901 | params_to_optimize = [] | 923 | params_to_optimize = [] |
| 902 | group_labels = [] | 924 | group_labels = [] |
| @@ -930,15 +952,6 @@ def main(): | |||
| 930 | ] | 952 | ] |
| 931 | group_labels += ["unet", "text"] | 953 | group_labels += ["unet", "text"] |
| 932 | 954 | ||
| 933 | lora_optimizer = create_optimizer(params_to_optimize) | ||
| 934 | |||
| 935 | lora_lr_scheduler = create_lr_scheduler( | ||
| 936 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 937 | optimizer=lora_optimizer, | ||
| 938 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | ||
| 939 | train_epochs=num_train_epochs, | ||
| 940 | ) | ||
| 941 | |||
| 942 | training_iter = 0 | 955 | training_iter = 0 |
| 943 | 956 | ||
| 944 | while True: | 957 | while True: |
| @@ -952,6 +965,15 @@ def main(): | |||
| 952 | print(f"============ LoRA cycle {training_iter} ============") | 965 | print(f"============ LoRA cycle {training_iter} ============") |
| 953 | print("") | 966 | print("") |
| 954 | 967 | ||
| 968 | lora_optimizer = create_optimizer(params_to_optimize) | ||
| 969 | |||
| 970 | lora_lr_scheduler = create_lr_scheduler( | ||
| 971 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 972 | optimizer=lora_optimizer, | ||
| 973 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | ||
| 974 | train_epochs=num_train_epochs, | ||
| 975 | ) | ||
| 976 | |||
| 955 | lora_project = f"lora_{training_iter}" | 977 | lora_project = f"lora_{training_iter}" |
| 956 | lora_checkpoint_output_dir = output_dir / lora_project / "model" | 978 | lora_checkpoint_output_dir = output_dir / lora_project / "model" |
| 957 | lora_sample_output_dir = output_dir / lora_project / "samples" | 979 | lora_sample_output_dir = output_dir / lora_project / "samples" |
diff --git a/train_ti.py b/train_ti.py index ebac302..eb08bda 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -152,6 +152,11 @@ def parse_args(): | |||
| 152 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 152 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 153 | ) | 153 | ) |
| 154 | parser.add_argument( | 154 | parser.add_argument( |
| 155 | "--train_dir_embeddings", | ||
| 156 | action="store_true", | ||
| 157 | help="Train embeddings loaded from embeddings directory.", | ||
| 158 | ) | ||
| 159 | parser.add_argument( | ||
| 155 | "--collection", | 160 | "--collection", |
| 156 | type=str, | 161 | type=str, |
| 157 | nargs='*', | 162 | nargs='*', |
| @@ -404,6 +409,12 @@ def parse_args(): | |||
| 404 | help="If checkpoints are saved on maximum accuracy", | 409 | help="If checkpoints are saved on maximum accuracy", |
| 405 | ) | 410 | ) |
| 406 | parser.add_argument( | 411 | parser.add_argument( |
| 412 | "--sample_num", | ||
| 413 | type=int, | ||
| 414 | default=None, | ||
| 415 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
| 416 | ) | ||
| 417 | parser.add_argument( | ||
| 407 | "--sample_frequency", | 418 | "--sample_frequency", |
| 408 | type=int, | 419 | type=int, |
| 409 | default=1, | 420 | default=1, |
| @@ -669,9 +680,14 @@ def main(): | |||
| 669 | raise ValueError("--embeddings_dir must point to an existing directory") | 680 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 670 | 681 | ||
| 671 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 682 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 672 | embeddings.persist() | ||
| 673 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 683 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 674 | 684 | ||
| 685 | if args.train_dir_embeddings: | ||
| 686 | args.placeholder_tokens = added_tokens | ||
| 687 | print("Training embeddings from embeddings dir") | ||
| 688 | else: | ||
| 689 | embeddings.persist() | ||
| 690 | |||
| 675 | if args.scale_lr: | 691 | if args.scale_lr: |
| 676 | args.learning_rate = ( | 692 | args.learning_rate = ( |
| 677 | args.learning_rate * args.gradient_accumulation_steps * | 693 | args.learning_rate * args.gradient_accumulation_steps * |
| @@ -852,28 +868,8 @@ def main(): | |||
| 852 | args.num_train_steps / len(datamodule.train_dataset) | 868 | args.num_train_steps / len(datamodule.train_dataset) |
| 853 | ) * args.gradient_accumulation_steps | 869 | ) * args.gradient_accumulation_steps |
| 854 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 870 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
| 855 | 871 | if args.sample_num is not None: | |
| 856 | optimizer = create_optimizer( | 872 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| 857 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | ||
| 858 | lr=args.learning_rate, | ||
| 859 | ) | ||
| 860 | |||
| 861 | lr_scheduler = get_scheduler( | ||
| 862 | args.lr_scheduler, | ||
| 863 | optimizer=optimizer, | ||
| 864 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 865 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 866 | min_lr=args.lr_min_lr, | ||
| 867 | warmup_func=args.lr_warmup_func, | ||
| 868 | annealing_func=args.lr_annealing_func, | ||
| 869 | warmup_exp=args.lr_warmup_exp, | ||
| 870 | annealing_exp=args.lr_annealing_exp, | ||
| 871 | cycles=args.lr_cycles, | ||
| 872 | end_lr=1e3, | ||
| 873 | train_epochs=num_train_epochs, | ||
| 874 | warmup_epochs=args.lr_warmup_epochs, | ||
| 875 | mid_point=args.lr_mid_point, | ||
| 876 | ) | ||
| 877 | 873 | ||
| 878 | training_iter = 0 | 874 | training_iter = 0 |
| 879 | 875 | ||
| @@ -888,6 +884,28 @@ def main(): | |||
| 888 | print(f"------------ TI cycle {training_iter} ------------") | 884 | print(f"------------ TI cycle {training_iter} ------------") |
| 889 | print("") | 885 | print("") |
| 890 | 886 | ||
| 887 | optimizer = create_optimizer( | ||
| 888 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | ||
| 889 | lr=args.learning_rate, | ||
| 890 | ) | ||
| 891 | |||
| 892 | lr_scheduler = get_scheduler( | ||
| 893 | args.lr_scheduler, | ||
| 894 | optimizer=optimizer, | ||
| 895 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 896 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 897 | min_lr=args.lr_min_lr, | ||
| 898 | warmup_func=args.lr_warmup_func, | ||
| 899 | annealing_func=args.lr_annealing_func, | ||
| 900 | warmup_exp=args.lr_warmup_exp, | ||
| 901 | annealing_exp=args.lr_annealing_exp, | ||
| 902 | cycles=args.lr_cycles, | ||
| 903 | end_lr=1e3, | ||
| 904 | train_epochs=num_train_epochs, | ||
| 905 | warmup_epochs=args.lr_warmup_epochs, | ||
| 906 | mid_point=args.lr_mid_point, | ||
| 907 | ) | ||
| 908 | |||
| 891 | project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" | 909 | project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" |
| 892 | sample_output_dir = output_dir / project / "samples" | 910 | sample_output_dir = output_dir / project / "samples" |
| 893 | checkpoint_output_dir = output_dir / project / "checkpoints" | 911 | checkpoint_output_dir = output_dir / project / "checkpoints" |
| @@ -908,10 +926,6 @@ def main(): | |||
| 908 | placeholder_token_ids=placeholder_token_ids, | 926 | placeholder_token_ids=placeholder_token_ids, |
| 909 | ) | 927 | ) |
| 910 | 928 | ||
| 911 | response = input("Run another cycle? [y/n] ") | ||
| 912 | continue_training = response.lower().strip() != "n" | ||
| 913 | training_iter += 1 | ||
| 914 | |||
| 915 | if not args.sequential: | 929 | if not args.sequential: |
| 916 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 930 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
| 917 | else: | 931 | else: |
diff --git a/training/functional.py b/training/functional.py index e14aeea..46d25f6 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -644,11 +644,9 @@ def train( | |||
| 644 | min_snr_gamma: int = 5, | 644 | min_snr_gamma: int = 5, |
| 645 | **kwargs, | 645 | **kwargs, |
| 646 | ): | 646 | ): |
| 647 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 647 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
| 648 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) | 648 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) |
| 649 | 649 | ||
| 650 | kwargs.update(extra) | ||
| 651 | |||
| 652 | vae.to(accelerator.device, dtype=dtype) | 650 | vae.to(accelerator.device, dtype=dtype) |
| 653 | vae.requires_grad_(False) | 651 | vae.requires_grad_(False) |
| 654 | vae.eval() | 652 | vae.eval() |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 695174a..42624cd 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -198,7 +198,7 @@ def dreambooth_prepare( | |||
| 198 | 198 | ||
| 199 | text_encoder.text_model.embeddings.requires_grad_(False) | 199 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 200 | 200 | ||
| 201 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 201 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 202 | 202 | ||
| 203 | 203 | ||
| 204 | dreambooth_strategy = TrainingStrategy( | 204 | dreambooth_strategy = TrainingStrategy( |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index ae85401..73ec8f2 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -184,7 +184,7 @@ def lora_prepare( | |||
| 184 | 184 | ||
| 185 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) | 185 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) |
| 186 | 186 | ||
| 187 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 187 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 188 | 188 | ||
| 189 | 189 | ||
| 190 | lora_strategy = TrainingStrategy( | 190 | lora_strategy = TrainingStrategy( |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 9cdc1bb..363c3f9 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -207,7 +207,7 @@ def textual_inversion_prepare( | |||
| 207 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 207 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
| 208 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 208 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
| 209 | 209 | ||
| 210 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 210 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 211 | 211 | ||
| 212 | 212 | ||
| 213 | textual_inversion_strategy = TrainingStrategy( | 213 | textual_inversion_strategy = TrainingStrategy( |
