Diffstat (limited to 'train_ti.py')
-rw-r--r--   train_ti.py   30
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
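The alias-token block itself is unchanged; the commit only moves it under `if args.sequential:`, so placeholder/initializer pairs are registered as aliases only in sequential mode. The comprehension interleaves each placeholder with its initializer; a quick illustration with made-up token values:

    # Hypothetical values, for illustration only
    placeholder_tokens = ["<tok1>", "<tok2>"]
    initializer_tokens = ["girl", "boy"]

    alias_tokens = [
        item
        for pair in zip(placeholder_tokens, initializer_tokens)
        for item in pair
    ]
    assert alias_tokens == ["<tok1>", "girl", "<tok2>", "boy"]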
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
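The new `embeddings.persist()` call runs immediately after embeddings are loaded from `--embeddings_dir`, presumably folding the freshly loaded vectors into the permanent embedding matrix so they are kept as fixed vocabulary rather than lingering as trainable temp embeddings. A rough sketch of what such a persist step could look like; the class and attribute names here are hypothetical, not the repo's actual implementation:

    import torch
    import torch.nn as nn

    class ManagedEmbeddings(nn.Module):
        # Hypothetical stand-in for the embedding manager used by train_ti.py
        def __init__(self, num_tokens: int, dim: int):
            super().__init__()
            self.token_embedding = nn.Embedding(num_tokens, dim)       # permanent
            self.temp_token_embedding = nn.Embedding(num_tokens, dim)  # trainable
            self.temp_token_ids: list[int] = []

        def persist(self):
            # Copy the trainable rows back into the permanent matrix
            # and clear the temp bookkeeping.
            with torch.no_grad():
                ids = torch.tensor(self.temp_token_ids, dtype=torch.long)
                self.token_embedding.weight[ids] = self.temp_token_embedding.weight[ids]
            self.temp_token_ids = []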
@@ -682,7 +689,6 @@
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
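After `datamodule.setup()` the commit derives an epoch count whenever `--num_train_epochs` was left unset. One thing worth flagging: as committed, the expression divides the dataset length by the step budget, which collapses to a single epoch for any dataset smaller than `num_train_steps`; if the flag is meant as a total step budget, the reciprocal would be the expected form. A sketch of both readings, with hypothetical sizes and ignoring gradient accumulation:

    import math

    dataset_size = 120      # hypothetical len(datamodule.train_dataset)
    num_train_steps = 2000  # default of --num_train_steps

    # As committed: any dataset smaller than the step budget yields 1 epoch.
    epochs_as_committed = math.ceil(dataset_size / num_train_steps)  # == 1

    # Reading num_train_steps as a total step budget instead:
    epochs_as_budget = math.ceil(num_train_steps / dataset_size)     # == 17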
@@ -769,7 +780,7 @@
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e3,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -779,6 +790,7 @@
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        num_train_epochs=num_train_epochs,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,