summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py15
-rw-r--r--models/sparse.py14
2 files changed, 15 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index a356434..63a141f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
37 37
38 38
39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): 40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
41 super().__init__(config) 41 super().__init__(config)
42 42
43 self.token_embedding = embeddings.token_embedding 43 self.token_embedding = embeddings.token_embedding
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
49 device=self.token_embedding.weight.device, 49 device=self.token_embedding.weight.device,
50 dtype=self.token_embedding.weight.dtype, 50 dtype=self.token_embedding.weight.dtype,
51 ) 51 )
52 self.alpha = alpha
53 52
54 def resize(self, size: int): 53 def resize(self, size: int):
55 self.token_override_embedding.resize(size) 54 self.token_override_embedding.resize(size)
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
87 token_ids = torch.tensor(token_ids, dtype=torch.long) 86 token_ids = torch.tensor(token_ids, dtype=torch.long)
88 87
89 self.token_embedding.weight.data[token_ids] = initializer 88 self.token_embedding.weight.data[token_ids] = initializer
90 self.token_override_embedding.set(token_ids) 89 self.token_override_embedding.set(token_ids, initializer)
91 90
92 def load_embed(self, input_ids: list[int], filename: Path): 91 def load_embed(self, input_ids: list[int], filename: Path):
93 with safe_open(filename, framework="pt", device="cpu") as file: 92 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
101 embs, mask = self.token_override_embedding(input_ids) 100 embs, mask = self.token_override_embedding(input_ids)
102 if embs is not None: 101 if embs is not None:
103 input_ids = input_ids[mask] 102 input_ids = input_ids[mask]
104 self.token_embedding.weight.data[input_ids] += self.alpha * embs 103 self.token_embedding.weight.data[input_ids] = embs
105 self.token_override_embedding.unset(input_ids) 104 self.token_override_embedding.unset(input_ids)
106 105
107 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 106 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
108 if isinstance(input_ids, list): 107 if isinstance(input_ids, list):
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
111 embs = self.token_embedding(input_ids) 110 embs = self.token_embedding(input_ids)
112 embs_override, mask = self.token_override_embedding(input_ids) 111 embs_override, mask = self.token_override_embedding(input_ids)
113 if embs_override is not None: 112 if embs_override is not None:
114 embs[mask] += self.alpha * embs_override 113 embs[mask] = embs_override
115 114
116 return embs 115 return embs
117 116
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
135 return embeddings 134 return embeddings
136 135
137 136
138def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: 137def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
139 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) 138 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
140 text_encoder.text_model.embeddings = text_embeddings 139 text_encoder.text_model.embeddings = text_embeddings
141 return text_embeddings 140 return text_embeddings
diff --git a/models/sparse.py b/models/sparse.py
index 0b15454..8910316 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module):
13 self.params = nn.ParameterList() 13 self.params = nn.ParameterList()
14 self.mapping = torch.zeros(0, device=device, dtype=torch.long) 14 self.mapping = torch.zeros(0, device=device, dtype=torch.long)
15 15
16 def forward(self, input_ids: Optional[torch.LongTensor] = None): 16 def forward(self, input_ids: torch.LongTensor):
17 if input_ids is None:
18 input_ids = torch.arange(self.mapping.shape[0])
19
20 ids = self.mapping[input_ids.to(self.mapping.device)] 17 ids = self.mapping[input_ids.to(self.mapping.device)]
21 mask = ~(ids == -1) 18 mask = ~(ids == -1)
22 19
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module):
43 else: 40 else:
44 return [self.set(id) for id in input_ids] 41 return [self.set(id) for id in input_ids]
45 42
43 if tensor is None:
44 tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
45
46 if tensor.shape[-1] != self.embedding_dim:
47 raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
48
46 id = self.mapping[input_ids] 49 id = self.mapping[input_ids]
47 50
48 if id == -1: 51 if id == -1:
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module):
50 self.mapping[input_ids] = id 53 self.mapping[input_ids] = id
51 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) 54 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
52 55
53 self.params[id] = tensor if tensor is not None else torch.zeros( 56 self.params[id] = tensor
54 self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
55 57
56 def unset(self, input_ids: torch.LongTensor): 58 def unset(self, input_ids: torch.LongTensor):
57 self.mapping[input_ids] = -1 59 self.mapping[input_ids] = -1