-rw-r--r--   models/clip/embeddings.py   |  2
-rw-r--r--   models/sparse.py            |  2
-rw-r--r--   train_lora.py               | 32
-rw-r--r--   train_ti.py                 | 23
-rw-r--r--   training/strategy/lora.py   |  2
-rw-r--r--   training/strategy/ti.py     | 12

6 files changed, 49 insertions(+), 24 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 63a141f..6fda33c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
-        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
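Registering mapping as a buffer instead of assigning a plain tensor attribute makes the tensor part of the module's state: it follows .to()/.cuda() calls and is included in state_dict(), while still receiving no gradients. A minimal standalone sketch of the difference (generic modules, not this repository's PseudoSparseEmbedding):

    import torch
    import torch.nn as nn

    class WithBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # Buffer: moves with .to()/.cuda() and appears in state_dict(),
            # but is not a trainable parameter.
            self.register_buffer("mapping", torch.zeros(0, dtype=torch.long))

    class WithPlainAttribute(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain attribute: invisible to .to() and state_dict().
            self.mapping = torch.zeros(0, dtype=torch.long)

    print("mapping" in WithBuffer().state_dict())          # True
    print("mapping" in WithPlainAttribute().state_dict())  # False

Keeping mapping on the module's device is also what the persist() change in models/clip/embeddings.py appears to rely on when it builds input_ids on token_override_embedding.device.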
diff --git a/train_lora.py b/train_lora.py
index 1626be6..e4b5546 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -93,6 +93,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
@@ -592,6 +598,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -890,7 +902,7 @@ def main():
 
     pti_datamodule = create_datamodule(
         batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
 
@@ -906,7 +918,7 @@ def main():
     pti_optimizer = create_optimizer(
         [
             {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
                 "lr": args.learning_rate_pti,
                 "weight_decay": 0,
             },
@@ -937,7 +949,7 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    # embeddings.persist()
+    embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -962,13 +974,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-            "lr": args.learning_rate_text,
-            "weight_decay": 0,
-        })
-        group_labels.append("emb")
+    # if len(args.placeholder_tokens) != 0:
+    #     params_to_optimize.append({
+    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+    #         "lr": args.learning_rate_text,
+    #         "weight_decay": 0,
+    #     })
+    #     group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (
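The new --filter_tokens option lets the PTI dataset be filtered by a token list that is independent of --placeholder_tokens; when it is omitted it defaults to a copy of the placeholder tokens, so existing invocations behave as before. A small sketch of the argument and its normalization, assuming a plain argparse setup (the option definitions here are illustrative, not copied from the script):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--placeholder_tokens", type=str, nargs='*', default=[])
    parser.add_argument(
        "--filter_tokens",
        type=str,
        nargs='*',
        help="Tokens to filter the dataset by."
    )
    args = parser.parse_args(["--placeholder_tokens", "<tok-a>", "<tok-b>"])

    # Same normalization the diff adds after parsing:
    if args.filter_tokens is None:
        args.filter_tokens = args.placeholder_tokens.copy()   # default: all placeholders
    if isinstance(args.filter_tokens, str):
        args.filter_tokens = [args.filter_tokens]              # tolerate a bare string

    print(args.filter_tokens)  # ['<tok-a>', '<tok-b>']

With nargs='*' the command line already yields a list, so the isinstance(str) branch presumably guards values supplied through a config file. Separately, enabling embeddings.persist() after the PTI phase and commenting out the "emb" parameter group suggests the learned embeddings are baked into the base embedding and no longer trained during the LoRA phase.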
diff --git a/train_ti.py b/train_ti.py
index 48858cc..daf8bc5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+from typing import Union
 import math
 
 import torch
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
@@ -538,6 +545,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if args.sequential:
         args.alias_tokens += [
             item
@@ -779,13 +792,11 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
             sample_output_dir = output_dir / "samples"
-            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -800,6 +811,8 @@ def main():
 
         print(f"{i + 1}: {stats}")
 
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
             batch_size=args.train_batch_size,
@@ -820,7 +833,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -834,7 +847,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
            lr=args.learning_rate,
         )
 
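In train_ti.py each run may train only a subset of the placeholder tokens (for example with --sequential), so the global --filter_tokens list is intersected with the current run's tokens before being bound into keyword_filter via functools.partial. A rough sketch of that flow; keyword_filter here is a hypothetical stand-in that only mirrors the call shape from the diff:

    from functools import partial

    def keyword_filter(keywords, collections, exclude_collections, item):
        # Stand-in: keep items whose prompt mentions one of the keywords.
        return any(keyword in item["prompt"] for keyword in keywords)

    args_filter_tokens = ["<cat>", "<dog>"]    # --filter_tokens (invented names)
    placeholder_tokens = ["<dog>"]             # tokens trained in the current run

    # Same intersection the diff adds before building the datamodule:
    filter_tokens = [t for t in args_filter_tokens if t in placeholder_tokens]

    item_filter = partial(keyword_filter, filter_tokens, None, None)
    print(item_filter({"prompt": "a photo of <dog>"}))  # True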
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cfdc504..ae85401 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
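Switching from iterating the inner ParameterList (token_override_embedding.params) to calling token_override_embedding.parameters() decouples the callbacks from the module's internal layout: nn.Module.parameters() recursively yields every registered parameter, including those held in a ParameterList. A minimal sketch of the equivalence, assuming the override embedding stores its vectors in a ParameterList named params (as models/sparse.py suggests):

    import torch
    import torch.nn as nn

    class TokenOverrideEmbedding(nn.Module):
        # Simplified stand-in for PseudoSparseEmbedding.
        def __init__(self, n_tokens: int = 3, dim: int = 4):
            super().__init__()
            self.params = nn.ParameterList(
                nn.Parameter(torch.zeros(dim)) for _ in range(n_tokens)
            )

    emb = TokenOverrideEmbedding()

    # Old style: reach into the internal ParameterList.
    old = [p for p in emb.params if p.grad is not None]
    # New style: let nn.Module enumerate its registered parameters.
    new = [p for p in emb.parameters() if p.grad is not None]

    assert len(list(emb.params)) == len(list(emb.parameters()))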
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 720ebf3..289d6bd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,20 +84,20 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.parameters()
             )
         else:
             return nullcontext()
 
     @contextmanager
     def on_train(epoch: int):
-        text_encoder.text_model.embeddings.token_override_embedding.params.train()
+        text_encoder.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
-        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
+        text_encoder.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] or lrs["0"]
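The on_train/on_eval callbacks now toggle the whole text encoder rather than the override embedding's ParameterList. Calling train()/eval() on a parameter container only flips a flag on that container, while calling it on the root module propagates recursively to every submodule, which is what mode-dependent layers such as dropout actually respond to. A minimal illustration with a generic module (not the CLIP text encoder itself):

    import torch.nn as nn

    encoder = nn.Sequential(
        nn.Linear(8, 8),
        nn.Dropout(p=0.5),   # behaves differently in train vs. eval mode
    )

    # train()/eval() on the root module flips .training on every submodule.
    encoder.eval()
    print(encoder[1].training)   # False
    encoder.train()
    print(encoder[1].training)   # True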
