diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 66 |
1 files changed, 44 insertions, 22 deletions
diff --git a/train_lora.py b/train_lora.py index e81742a..4bbc64e 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -199,6 +199,11 @@ def parse_args(): | |||
199 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 199 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
200 | ) | 200 | ) |
201 | parser.add_argument( | 201 | parser.add_argument( |
202 | "--train_dir_embeddings", | ||
203 | action="store_true", | ||
204 | help="Train embeddings loaded from embeddings directory.", | ||
205 | ) | ||
206 | parser.add_argument( | ||
202 | "--collection", | 207 | "--collection", |
203 | type=str, | 208 | type=str, |
204 | nargs='*', | 209 | nargs='*', |
@@ -440,6 +445,12 @@ def parse_args(): | |||
440 | help="How often to save a checkpoint and sample image", | 445 | help="How often to save a checkpoint and sample image", |
441 | ) | 446 | ) |
442 | parser.add_argument( | 447 | parser.add_argument( |
448 | "--sample_num", | ||
449 | type=int, | ||
450 | default=None, | ||
451 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
452 | ) | ||
453 | parser.add_argument( | ||
443 | "--sample_image_size", | 454 | "--sample_image_size", |
444 | type=int, | 455 | type=int, |
445 | default=768, | 456 | default=768, |
@@ -681,27 +692,36 @@ def main(): | |||
681 | embeddings.persist() | 692 | embeddings.persist() |
682 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 693 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") |
683 | 694 | ||
695 | placeholder_token_ids = [] | ||
696 | |||
684 | if args.embeddings_dir is not None: | 697 | if args.embeddings_dir is not None: |
685 | embeddings_dir = Path(args.embeddings_dir) | 698 | embeddings_dir = Path(args.embeddings_dir) |
686 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 699 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
687 | raise ValueError("--embeddings_dir must point to an existing directory") | 700 | raise ValueError("--embeddings_dir must point to an existing directory") |
688 | 701 | ||
689 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 702 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
690 | embeddings.persist() | ||
691 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 703 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
692 | 704 | ||
693 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 705 | if args.train_dir_embeddings: |
694 | tokenizer=tokenizer, | 706 | args.placeholder_tokens = added_tokens |
695 | embeddings=embeddings, | 707 | placeholder_token_ids = added_ids |
696 | placeholder_tokens=args.placeholder_tokens, | 708 | print("Training embeddings from embeddings dir") |
697 | initializer_tokens=args.initializer_tokens, | 709 | else: |
698 | num_vectors=args.num_vectors, | 710 | embeddings.persist() |
699 | initializer_noise=args.initializer_noise, | 711 | |
700 | ) | 712 | if not args.train_dir_embeddings: |
701 | stats = list(zip( | 713 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
702 | args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 714 | tokenizer=tokenizer, |
703 | )) | 715 | embeddings=embeddings, |
704 | print(f"Training embeddings: {stats}") | 716 | placeholder_tokens=args.placeholder_tokens, |
717 | initializer_tokens=args.initializer_tokens, | ||
718 | num_vectors=args.num_vectors, | ||
719 | initializer_noise=args.initializer_noise, | ||
720 | ) | ||
721 | stats = list(zip( | ||
722 | args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | ||
723 | )) | ||
724 | print(f"Training embeddings: {stats}") | ||
705 | 725 | ||
706 | if args.scale_lr: | 726 | if args.scale_lr: |
707 | args.learning_rate_unet = ( | 727 | args.learning_rate_unet = ( |
@@ -897,6 +917,8 @@ def main(): | |||
897 | args.num_train_steps / len(lora_datamodule.train_dataset) | 917 | args.num_train_steps / len(lora_datamodule.train_dataset) |
898 | ) * args.gradient_accumulation_steps | 918 | ) * args.gradient_accumulation_steps |
899 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 919 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
920 | if args.sample_num is not None: | ||
921 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
900 | 922 | ||
901 | params_to_optimize = [] | 923 | params_to_optimize = [] |
902 | group_labels = [] | 924 | group_labels = [] |
@@ -930,15 +952,6 @@ def main(): | |||
930 | ] | 952 | ] |
931 | group_labels += ["unet", "text"] | 953 | group_labels += ["unet", "text"] |
932 | 954 | ||
933 | lora_optimizer = create_optimizer(params_to_optimize) | ||
934 | |||
935 | lora_lr_scheduler = create_lr_scheduler( | ||
936 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
937 | optimizer=lora_optimizer, | ||
938 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | ||
939 | train_epochs=num_train_epochs, | ||
940 | ) | ||
941 | |||
942 | training_iter = 0 | 955 | training_iter = 0 |
943 | 956 | ||
944 | while True: | 957 | while True: |
@@ -952,6 +965,15 @@ def main(): | |||
952 | print(f"============ LoRA cycle {training_iter} ============") | 965 | print(f"============ LoRA cycle {training_iter} ============") |
953 | print("") | 966 | print("") |
954 | 967 | ||
968 | lora_optimizer = create_optimizer(params_to_optimize) | ||
969 | |||
970 | lora_lr_scheduler = create_lr_scheduler( | ||
971 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
972 | optimizer=lora_optimizer, | ||
973 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | ||
974 | train_epochs=num_train_epochs, | ||
975 | ) | ||
976 | |||
955 | lora_project = f"lora_{training_iter}" | 977 | lora_project = f"lora_{training_iter}" |
956 | lora_checkpoint_output_dir = output_dir / lora_project / "model" | 978 | lora_checkpoint_output_dir = output_dir / lora_project / "model" |
957 | lora_sample_output_dir = output_dir / lora_project / "samples" | 979 | lora_sample_output_dir = output_dir / lora_project / "samples" |