diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index 48858cc..daf8bc5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -3,6 +3,7 @@ import datetime | |||
| 3 | import logging | 3 | import logging |
| 4 | from functools import partial | 4 | from functools import partial |
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from typing import Union | ||
| 6 | import math | 7 | import math |
| 7 | 8 | ||
| 8 | import torch | 9 | import torch |
| @@ -75,6 +76,12 @@ def parse_args(): | |||
| 75 | help="A token to use as initializer word." | 76 | help="A token to use as initializer word." |
| 76 | ) | 77 | ) |
| 77 | parser.add_argument( | 78 | parser.add_argument( |
| 79 | "--filter_tokens", | ||
| 80 | type=str, | ||
| 81 | nargs='*', | ||
| 82 | help="Tokens to filter the dataset by." | ||
| 83 | ) | ||
| 84 | parser.add_argument( | ||
| 78 | "--initializer_noise", | 85 | "--initializer_noise", |
| 79 | type=float, | 86 | type=float, |
| 80 | default=0, | 87 | default=0, |
| @@ -538,6 +545,12 @@ def parse_args(): | |||
| 538 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 545 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: |
| 539 | raise ValueError("--alias_tokens must be a list with an even number of items") | 546 | raise ValueError("--alias_tokens must be a list with an even number of items") |
| 540 | 547 | ||
| 548 | if args.filter_tokens is None: | ||
| 549 | args.filter_tokens = args.placeholder_tokens.copy() | ||
| 550 | |||
| 551 | if isinstance(args.filter_tokens, str): | ||
| 552 | args.filter_tokens = [args.filter_tokens] | ||
| 553 | |||
| 541 | if args.sequential: | 554 | if args.sequential: |
| 542 | args.alias_tokens += [ | 555 | args.alias_tokens += [ |
| 543 | item | 556 | item |
| @@ -779,13 +792,11 @@ def main(): | |||
| 779 | sample_image_size=args.sample_image_size, | 792 | sample_image_size=args.sample_image_size, |
| 780 | ) | 793 | ) |
| 781 | 794 | ||
| 782 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 795 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): |
| 783 | if len(placeholder_tokens) == 1: | 796 | if len(placeholder_tokens) == 1: |
| 784 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" | 797 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" |
| 785 | metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png" | ||
| 786 | else: | 798 | else: |
| 787 | sample_output_dir = output_dir / "samples" | 799 | sample_output_dir = output_dir / "samples" |
| 788 | metrics_output_file = output_dir / "lr.png" | ||
| 789 | 800 | ||
| 790 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 801 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 791 | tokenizer=tokenizer, | 802 | tokenizer=tokenizer, |
| @@ -800,6 +811,8 @@ def main(): | |||
| 800 | 811 | ||
| 801 | print(f"{i + 1}: {stats}") | 812 | print(f"{i + 1}: {stats}") |
| 802 | 813 | ||
| 814 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | ||
| 815 | |||
| 803 | datamodule = VlpnDataModule( | 816 | datamodule = VlpnDataModule( |
| 804 | data_file=args.train_data_file, | 817 | data_file=args.train_data_file, |
| 805 | batch_size=args.train_batch_size, | 818 | batch_size=args.train_batch_size, |
| @@ -820,7 +833,7 @@ def main(): | |||
| 820 | train_set_pad=args.train_set_pad, | 833 | train_set_pad=args.train_set_pad, |
| 821 | valid_set_pad=args.valid_set_pad, | 834 | valid_set_pad=args.valid_set_pad, |
| 822 | seed=args.seed, | 835 | seed=args.seed, |
| 823 | filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), | 836 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
| 824 | dtype=weight_dtype | 837 | dtype=weight_dtype |
| 825 | ) | 838 | ) |
| 826 | datamodule.setup() | 839 | datamodule.setup() |
| @@ -834,7 +847,7 @@ def main(): | |||
| 834 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 847 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
| 835 | 848 | ||
| 836 | optimizer = create_optimizer( | 849 | optimizer = create_optimizer( |
| 837 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 850 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
| 838 | lr=args.learning_rate, | 851 | lr=args.learning_rate, |
| 839 | ) | 852 | ) |
| 840 | 853 | ||
