diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 34 |
1 files changed, 15 insertions, 19 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 95904cf..2b315c4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -42,16 +42,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
42 | self.init_temp_embeddings() | 42 | self.init_temp_embeddings() |
43 | 43 | ||
44 | def init_temp_embeddings(self): | 44 | def init_temp_embeddings(self): |
45 | self.temp_token_embedding = nn.ParameterList() | 45 | self.temp_token_embedding = nn.Embedding( |
46 | 0, | ||
47 | self.token_embedding.embedding_dim, | ||
48 | device=self.token_embedding.weight.device, | ||
49 | dtype=self.token_embedding.weight.dtype | ||
50 | ) | ||
46 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
47 | 52 | ||
48 | def resize(self, size: int): | 53 | def resize(self, size: int): |
49 | for _ in range(len(self.temp_token_embedding), size): | 54 | self.temp_token_embedding = resize_embedding( |
50 | self.temp_token_embedding.append(torch.zeros( | 55 | self.temp_token_embedding, |
51 | self.token_embedding.embedding_dim, | 56 | size - self.num_permanent_embeddings, |
52 | device=self.token_embedding.weight.device, | 57 | self.initializer_factor |
53 | dtype=self.token_embedding.weight.dtype, | 58 | ) |
54 | )) | ||
55 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
56 | 60 | ||
57 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
@@ -78,10 +82,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
78 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 82 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
79 | 83 | ||
80 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 84 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
81 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | 85 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) |
82 | 86 | self.temp_token_embedding.weight.data[mask] = initializer | |
83 | for i, id in enumerate(mask): | ||
84 | self.temp_token_embedding[id] = initializer[i] | ||
85 | 87 | ||
86 | def load_embed(self, input_ids: list[int], filename: Path): | 88 | def load_embed(self, input_ids: list[int], filename: Path): |
87 | with safe_open(filename, framework="pt", device="cpu") as file: | 89 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -91,8 +93,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
91 | save_file({"embed": self.get_embed(input_ids)}, filename) | 93 | save_file({"embed": self.get_embed(input_ids)}, filename) |
92 | 94 | ||
93 | def persist(self): | 95 | def persist(self): |
94 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): | 96 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
95 | self.token_embedding.weight.data[id] = emb | ||
96 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | 97 | self.num_permanent_embeddings = self.token_embedding.num_embeddings |
97 | self.init_temp_embeddings() | 98 | self.init_temp_embeddings() |
98 | 99 | ||
@@ -111,12 +112,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
111 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | 112 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) |
112 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | 113 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() |
113 | 114 | ||
114 | if len(temp_token_ids): | 115 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) |
115 | embeds_override = torch.stack([ | ||
116 | self.temp_token_embedding[id] | ||
117 | for id in temp_token_ids | ||
118 | ]) | ||
119 | embeds[embeds_mask] = embeds_override | ||
120 | 116 | ||
121 | return embeds | 117 | return embeds |
122 | 118 | ||