 models/clip/embeddings.py |  2
 models/sparse.py          |  2
 train_lora.py             | 32
 train_ti.py               | 23
 training/strategy/lora.py |  2
 training/strategy/ti.py   | 12
 6 files changed, 49 insertions(+), 24 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 63a141f..6fda33c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
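A minimal sketch of the device issue this hunk addresses (the names below are stand-ins, not the project's classes): torch.arange defaults to the CPU, so indexing a CUDA-resident embedding with it fails unless the index tensor is created on the embedding's device.

import torch
import torch.nn as nn

# Hypothetical stand-in for token_override_embedding; assumes a CUDA device is available.
emb = nn.Embedding(16, 8).to("cuda")

# Without device=..., input_ids would stay on the CPU and the lookup below would
# raise a device-mismatch error once the module has been moved to the GPU.
input_ids = torch.arange(emb.num_embeddings, device=emb.weight.device)
out = emb(input_ids)  # shape (16, 8), on the same device as the embedding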
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
-        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
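A minimal sketch (stand-in class, not the project's PseudoSparseEmbedding) of what register_buffer changes: a buffer follows module.to(device) and is included in state_dict(), whereas a plain tensor attribute does neither.

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: moved by .to()/.cuda() and saved via state_dict().
        self.register_buffer("mapping", torch.zeros(0, dtype=torch.long))

class WithAttribute(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: ignored by .to() and missing from state_dict().
        self.mapping = torch.zeros(0, dtype=torch.long)

print("mapping" in WithBuffer().state_dict())     # True
print("mapping" in WithAttribute().state_dict())  # False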
diff --git a/train_lora.py b/train_lora.py
index 1626be6..e4b5546 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -93,6 +93,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
@@ -592,6 +598,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -890,7 +902,7 @@ def main():
 
     pti_datamodule = create_datamodule(
         batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
 
@@ -906,7 +918,7 @@ def main():
     pti_optimizer = create_optimizer(
         [
             {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
                 "lr": args.learning_rate_pti,
                 "weight_decay": 0,
             },
@@ -937,7 +949,7 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    # embeddings.persist()
+    embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -962,13 +974,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-            "lr": args.learning_rate_text,
-            "weight_decay": 0,
-        })
-        group_labels.append("emb")
+    # if len(args.placeholder_tokens) != 0:
+    #     params_to_optimize.append({
+    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+    #         "lr": args.learning_rate_text,
+    #         "weight_decay": 0,
+    #     })
+    #     group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (
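A rough sketch of how functools.partial is used to pre-bind the filter arguments at the call sites above; keyword_filter's real signature is not shown in this diff, so the predicate below is hypothetical.

from functools import partial

# Hypothetical filter with the same argument order as the call sites above:
# bound tokens/collections first, the dataset item last.
def keyword_filter(tokens, collections, exclude_collections, item):
    return any(token in item["prompt"] for token in tokens)

items = [{"prompt": "a photo of <embedding>"}, {"prompt": "a photo of a dog"}]
filter_fn = partial(keyword_filter, ["<embedding>"], None, None)
print([item for item in items if filter_fn(item)])  # keeps only the first item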
diff --git a/train_ti.py b/train_ti.py
index 48858cc..daf8bc5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+from typing import Union
 import math
 
 import torch
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
@@ -538,6 +545,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if args.sequential:
         args.alias_tokens += [
             item
@@ -779,13 +792,11 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
             sample_output_dir = output_dir / "samples"
-            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -800,6 +811,8 @@ def main():
 
         print(f"{i + 1}: {stats}")
 
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
             batch_size=args.train_batch_size,
@@ -820,7 +833,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -834,7 +847,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             lr=args.learning_rate,
         )
 
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cfdc504..ae85401 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
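A small sketch of why switching from iterating .params to calling .parameters() yields the same tensors here, assuming (as in models/sparse.py above) that the override embedding keeps its weights in an nn.ParameterList named params; the module below is a stand-in, not the project's class.

import torch
import torch.nn as nn

# Stand-in for the override embedding: weights stored in an nn.ParameterList.
override = nn.Module()
override.params = nn.ParameterList([nn.Parameter(torch.zeros(4)), nn.Parameter(torch.ones(4))])

# Module.parameters() recurses into the ParameterList, so both spellings
# iterate exactly the same Parameter objects.
assert all(a is b for a, b in zip(override.parameters(), override.params))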
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 720ebf3..289d6bd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,20 +84,20 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.parameters()
             )
         else:
             return nullcontext()
 
     @contextmanager
     def on_train(epoch: int):
-        text_encoder.text_model.embeddings.token_override_embedding.params.train()
+        text_encoder.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
-        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
+        text_encoder.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] or lrs["0"]