diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 23 | 
1 file changed, 18 insertions, 5 deletions
| diff --git a/train_ti.py b/train_ti.py index 48858cc..daf8bc5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -3,6 +3,7 @@ import datetime | |||
| 3 | import logging | 3 | import logging | 
| 4 | from functools import partial | 4 | from functools import partial | 
| 5 | from pathlib import Path | 5 | from pathlib import Path | 
| 6 | from typing import Union | ||
| 6 | import math | 7 | import math | 
| 7 | 8 | ||
| 8 | import torch | 9 | import torch | 
| @@ -75,6 +76,12 @@ def parse_args(): | |||
| 75 | help="A token to use as initializer word." | 76 | help="A token to use as initializer word." | 
| 76 | ) | 77 | ) | 
| 77 | parser.add_argument( | 78 | parser.add_argument( | 
| 79 | "--filter_tokens", | ||
| 80 | type=str, | ||
| 81 | nargs='*', | ||
| 82 | help="Tokens to filter the dataset by." | ||
| 83 | ) | ||
| 84 | parser.add_argument( | ||
| 78 | "--initializer_noise", | 85 | "--initializer_noise", | 
| 79 | type=float, | 86 | type=float, | 
| 80 | default=0, | 87 | default=0, | 
| @@ -538,6 +545,12 @@ def parse_args(): | |||
| 538 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 545 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 
| 539 | raise ValueError("--alias_tokens must be a list with an even number of items") | 546 | raise ValueError("--alias_tokens must be a list with an even number of items") | 
| 540 | 547 | ||
| 548 | if args.filter_tokens is None: | ||
| 549 | args.filter_tokens = args.placeholder_tokens.copy() | ||
| 550 | |||
| 551 | if isinstance(args.filter_tokens, str): | ||
| 552 | args.filter_tokens = [args.filter_tokens] | ||
| 553 | |||
| 541 | if args.sequential: | 554 | if args.sequential: | 
| 542 | args.alias_tokens += [ | 555 | args.alias_tokens += [ | 
| 543 | item | 556 | item | 
| @@ -779,13 +792,11 @@ def main(): | |||
| 779 | sample_image_size=args.sample_image_size, | 792 | sample_image_size=args.sample_image_size, | 
| 780 | ) | 793 | ) | 
| 781 | 794 | ||
| 782 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 795 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 
| 783 | if len(placeholder_tokens) == 1: | 796 | if len(placeholder_tokens) == 1: | 
| 784 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" | 797 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" | 
| 785 | metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png" | ||
| 786 | else: | 798 | else: | 
| 787 | sample_output_dir = output_dir / "samples" | 799 | sample_output_dir = output_dir / "samples" | 
| 788 | metrics_output_file = output_dir / "lr.png" | ||
| 789 | 800 | ||
| 790 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 801 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 
| 791 | tokenizer=tokenizer, | 802 | tokenizer=tokenizer, | 
| @@ -800,6 +811,8 @@ def main(): | |||
| 800 | 811 | ||
| 801 | print(f"{i + 1}: {stats}") | 812 | print(f"{i + 1}: {stats}") | 
| 802 | 813 | ||
| 814 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | ||
| 815 | |||
| 803 | datamodule = VlpnDataModule( | 816 | datamodule = VlpnDataModule( | 
| 804 | data_file=args.train_data_file, | 817 | data_file=args.train_data_file, | 
| 805 | batch_size=args.train_batch_size, | 818 | batch_size=args.train_batch_size, | 
| @@ -820,7 +833,7 @@ def main(): | |||
| 820 | train_set_pad=args.train_set_pad, | 833 | train_set_pad=args.train_set_pad, | 
| 821 | valid_set_pad=args.valid_set_pad, | 834 | valid_set_pad=args.valid_set_pad, | 
| 822 | seed=args.seed, | 835 | seed=args.seed, | 
| 823 | filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), | 836 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 
| 824 | dtype=weight_dtype | 837 | dtype=weight_dtype | 
| 825 | ) | 838 | ) | 
| 826 | datamodule.setup() | 839 | datamodule.setup() | 
| @@ -834,7 +847,7 @@ def main(): | |||
| 834 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 847 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 
| 835 | 848 | ||
| 836 | optimizer = create_optimizer( | 849 | optimizer = create_optimizer( | 
| 837 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 850 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 
| 838 | lr=args.learning_rate, | 851 | lr=args.learning_rate, | 
| 839 | ) | 852 | ) | 
| 840 | 853 | ||
