diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 88e0cc0..c9c788c 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
66 | self.initializer_factor = config.initializer_factor | 66 | self.initializer_factor = config.initializer_factor |
67 | 67 | ||
68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) | 68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) |
69 | self.temp_token_embedding = nn.Embedding( | ||
70 | self.token_embedding.num_embeddings, | ||
71 | self.token_embedding.embedding_dim, | ||
72 | device=self.token_embedding.weight.device, | ||
73 | dtype=self.token_embedding.weight.dtype | ||
74 | ) | ||
75 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
69 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 76 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
70 | 77 | ||
71 | def reset_overlay(self): | 78 | def reset_overlay(self): |
72 | self.overlay.reset() | 79 | self.overlay.reset() |
73 | 80 | ||
74 | def resize(self, size: int): | 81 | def resize(self, size: int): |
82 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
75 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 83 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
76 | 84 | ||
77 | def add_embed( | 85 | def add_embed( |
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
106 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 114 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
107 | 115 | ||
108 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 116 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
117 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
109 | self.token_embedding.weight.data[token_ids] = initializer | 118 | self.token_embedding.weight.data[token_ids] = initializer |
110 | 119 | ||
111 | def load_embed(self, input_ids: list[int], filename: Path): | 120 | def load_embed(self, input_ids: list[int], filename: Path): |
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
116 | save_file({"embed": self.get_embed(input_ids)}, filename) | 125 | save_file({"embed": self.get_embed(input_ids)}, filename) |
117 | 126 | ||
118 | def persist(self): | 127 | def persist(self): |
119 | self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( | 128 | embeds = self.temp_token_embedding.weight.data[self.temp_token_ids] |
120 | self.token_embedding.weight.data[self.temp_token_ids] | 129 | self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds) |
121 | ) | ||
122 | self.overlay.reset() | 130 | self.overlay.reset() |
123 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 131 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
124 | 132 | ||
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
127 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 135 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
128 | 136 | ||
129 | embeds = self.token_embedding(input_ids) | 137 | embeds = self.token_embedding(input_ids) |
138 | |||
130 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 139 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
131 | embeds[mask] += self.overlay(embeds[mask]) | 140 | |
141 | temp_embeds = self.temp_token_embedding(input_ids[mask]) | ||
142 | embeds[mask] = temp_embeds + self.overlay(temp_embeds) | ||
132 | 143 | ||
133 | return embeds | 144 | return embeds |
134 | 145 | ||