 common.py                 | 24 ++++++++++++++++--------
 models/clip/embeddings.py | 30 ++++++++++++++++--------------
 models/clip/tokenizer.py  |  6 ++----
 train_ti.py               |  1 +
 4 files changed, 35 insertions(+), 26 deletions(-)
diff --git a/common.py b/common.py
index 691be4e..0887197 100644
--- a/common.py
+++ b/common.py
@@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         return []
 
     filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
-    tokens = [filename.stem for filename in filenames]
 
-    for filename in embeddings_dir.iterdir():
-        if filename.is_file():
-            with safe_open(filename, framework="pt", device="cpu") as file:
-                embed = file.get_tensor("embed")
-                added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-                embeddings.add_embed(added.ids, embed)
+    new_tokens = []
+    new_embeds = []
 
-    return tokens
+    for filename in filenames:
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            embed = file.get_tensor("embed")
+
+            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+            new_tokens.append(added)
+            new_embeds.append(embed)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_token, embeds) in zip(new_tokens, new_embeds):
+        embeddings.add_embed(new_token.ids, embeds)
+
+    return new_tokens
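
Note: load_embeddings_from_dir now collects every token and tensor first, resizes the embeddings once to the final tokenizer length, and only then copies the loaded vectors in; it also returns the MultiCLIPTokenizerItem entries from add_multi_tokens rather than the filename stems. A minimal caller sketch, assuming tokenizer and embeddings are an already constructed MultiCLIPTokenizer and ManagedCLIPTextEmbeddings and that the third argument is the embeddings directory (the path below is an invented example):

    from pathlib import Path

    # Invented example path; the real directory comes from the caller's config.
    embeddings_dir = Path("embeddings")

    new_tokens = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)

    # The return value is now a list of MultiCLIPTokenizerItem entries rather
    # than plain token strings, so callers read the text and ids per item.
    for item in new_tokens:
        print(item.token, item.ids)
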
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 91a575d..cab1515 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -12,18 +12,22 @@ from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
-def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
+def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
 
+    if old_num_embeddings == new_num_embeddings:
+        return old_embedding
+
+    n = min(old_num_embeddings, new_num_embeddings)
+
     new_embedding = nn.Embedding(
-        old_num_embeddings + n,
+        new_num_embeddings,
         old_embedding_dim,
         device=old_embedding.weight.device,
         dtype=old_embedding.weight.dtype
     )
-    new_embedding.weight.data.zero_()
-    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
-
+    new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
     return new_embedding
 
 
@@ -40,9 +44,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data.zero_()
+        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
+    def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor)
+        self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor)
+
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
@@ -55,20 +63,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
-            initializer = self.get_embed(initializer)
-
-        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
-        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
+            initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-                dtype=self.temp_token_embedding.weight.dtype)
-        else:
-            self.temp_token_embedding.weight.data[token_ids].zero_()
+            self.temp_token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
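
Note: the new resize_embedding helper grows (or shrinks) an nn.Embedding to an explicit target size: the first min(old, new) rows are copied over unchanged and any newly created rows keep the normal_(mean=0.0, std=initializer_factor * 0.02) initialization. A standalone sketch of that behaviour, assuming the patched models/clip/embeddings.py is importable; the sizes and assertions are illustrative only:

    import torch
    import torch.nn as nn

    from models.clip.embeddings import resize_embedding

    old = nn.Embedding(4, 8)
    old.weight.data.fill_(1.0)  # mark the original rows so the copy is visible

    new = resize_embedding(old, 6, initializer_factor=1.0)

    assert new.num_embeddings == 6
    # The first min(4, 6) = 4 rows are copied verbatim ...
    assert torch.equal(new.weight.data[:4], old.weight.data)
    # ... while rows 4 and 5 come from the normal_(mean=0.0, std=0.02) init.
    print(new.weight.data[4:])
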
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 63566e0..fbfe790 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,7 +8,6 @@ from transformers import CLIPTokenizer
 
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
-    meta_id: int
     ids: list[int]
 
 
@@ -38,11 +37,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         super().add_tokens(multi_token)
 
         ids = super().convert_tokens_to_ids(multi_token)
-        meta_id = ids[0]
 
-        self.token_map[meta_id] = ids
+        self.token_map[ids[0]] = ids
 
-        return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
+        return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int, vector_shuffle=True):
         if id in self.token_map:
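
Note: with meta_id gone from MultiCLIPTokenizerItem, the meta id is simply the first entry of ids, which add_multi_tokens also uses as the key into token_map. A short caller sketch; the placeholder token and vector count are invented example values:

    # Invented example values for illustration.
    item = tokenizer.add_multi_tokens("<token>", 3)

    meta_id = item.ids[0]   # previously exposed as item.meta_id
    vector_ids = item.ids   # one id per embedding vector
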
diff --git a/train_ti.py b/train_ti.py
index 3776eb2..19348e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -535,6 +535,7 @@ def main():
     ]
 
     new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+    embeddings.resize(len(tokenizer))
 
     for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
         embeddings.add_embed(new_token.ids, init_ids)
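
Note: in train_ti.py the only change is the ordering: register the placeholder tokens, resize the embedding matrices once to the new tokenizer length, then copy the initializer embeddings in. The resize has to come first because add_embed now only writes weight rows indexed by the new token ids. A condensed sketch of the sequence; the placeholder token and vector count are invented, and initializer_token_ids is assumed to be prepared as in the surrounding code:

    # Invented example values; train_ti.py takes these from its CLI arguments.
    placeholder_token = ["<cat-toy>"]
    num_vectors = [4]

    new_tokens = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
    embeddings.resize(len(tokenizer))  # grow both matrices before add_embed

    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
        embeddings.add_embed(new_token.ids, init_ids)
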