diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index 48858cc..daf8bc5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -3,6 +3,7 @@ import datetime | |||
3 | import logging | 3 | import logging |
4 | from functools import partial | 4 | from functools import partial |
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from typing import Union | ||
6 | import math | 7 | import math |
7 | 8 | ||
8 | import torch | 9 | import torch |
@@ -75,6 +76,12 @@ def parse_args(): | |||
75 | help="A token to use as initializer word." | 76 | help="A token to use as initializer word." |
76 | ) | 77 | ) |
77 | parser.add_argument( | 78 | parser.add_argument( |
79 | "--filter_tokens", | ||
80 | type=str, | ||
81 | nargs='*', | ||
82 | help="Tokens to filter the dataset by." | ||
83 | ) | ||
84 | parser.add_argument( | ||
78 | "--initializer_noise", | 85 | "--initializer_noise", |
79 | type=float, | 86 | type=float, |
80 | default=0, | 87 | default=0, |
@@ -538,6 +545,12 @@ def parse_args(): | |||
538 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 545 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: |
539 | raise ValueError("--alias_tokens must be a list with an even number of items") | 546 | raise ValueError("--alias_tokens must be a list with an even number of items") |
540 | 547 | ||
548 | if args.filter_tokens is None: | ||
549 | args.filter_tokens = args.placeholder_tokens.copy() | ||
550 | |||
551 | if isinstance(args.filter_tokens, str): | ||
552 | args.filter_tokens = [args.filter_tokens] | ||
553 | |||
541 | if args.sequential: | 554 | if args.sequential: |
542 | args.alias_tokens += [ | 555 | args.alias_tokens += [ |
543 | item | 556 | item |
@@ -779,13 +792,11 @@ def main(): | |||
779 | sample_image_size=args.sample_image_size, | 792 | sample_image_size=args.sample_image_size, |
780 | ) | 793 | ) |
781 | 794 | ||
782 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 795 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): |
783 | if len(placeholder_tokens) == 1: | 796 | if len(placeholder_tokens) == 1: |
784 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" | 797 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" |
785 | metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png" | ||
786 | else: | 798 | else: |
787 | sample_output_dir = output_dir / "samples" | 799 | sample_output_dir = output_dir / "samples" |
788 | metrics_output_file = output_dir / "lr.png" | ||
789 | 800 | ||
790 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 801 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
791 | tokenizer=tokenizer, | 802 | tokenizer=tokenizer, |
@@ -800,6 +811,8 @@ def main(): | |||
800 | 811 | ||
801 | print(f"{i + 1}: {stats}") | 812 | print(f"{i + 1}: {stats}") |
802 | 813 | ||
814 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | ||
815 | |||
803 | datamodule = VlpnDataModule( | 816 | datamodule = VlpnDataModule( |
804 | data_file=args.train_data_file, | 817 | data_file=args.train_data_file, |
805 | batch_size=args.train_batch_size, | 818 | batch_size=args.train_batch_size, |
@@ -820,7 +833,7 @@ def main(): | |||
820 | train_set_pad=args.train_set_pad, | 833 | train_set_pad=args.train_set_pad, |
821 | valid_set_pad=args.valid_set_pad, | 834 | valid_set_pad=args.valid_set_pad, |
822 | seed=args.seed, | 835 | seed=args.seed, |
823 | filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), | 836 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
824 | dtype=weight_dtype | 837 | dtype=weight_dtype |
825 | ) | 838 | ) |
826 | datamodule.setup() | 839 | datamodule.setup() |
@@ -834,7 +847,7 @@ def main(): | |||
834 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 847 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
835 | 848 | ||
836 | optimizer = create_optimizer( | 849 | optimizer = create_optimizer( |
837 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 850 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
838 | lr=args.learning_rate, | 851 | lr=args.learning_rate, |
839 | ) | 852 | ) |
840 | 853 | ||