summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py40
1 files changed, 23 insertions, 17 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 870ee49..95904cf 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
42 self.init_temp_embeddings() 42 self.init_temp_embeddings()
43 43
44 def init_temp_embeddings(self): 44 def init_temp_embeddings(self):
45 self.temp_token_embedding = nn.Embedding( 45 self.temp_token_embedding = nn.ParameterList()
46 0,
47 self.token_embedding.embedding_dim,
48 device=self.token_embedding.weight.device,
49 dtype=self.token_embedding.weight.dtype
50 )
51 self.temp_token_ids = torch.tensor([], dtype=torch.long) 46 self.temp_token_ids = torch.tensor([], dtype=torch.long)
52 47
53 def resize(self, size: int): 48 def resize(self, size: int):
54 self.temp_token_embedding = resize_embedding( 49 for _ in range(len(self.temp_token_embedding), size):
55 self.temp_token_embedding, 50 self.temp_token_embedding.append(torch.zeros(
56 size - self.num_permanent_embeddings, 51 self.token_embedding.embedding_dim,
57 self.initializer_factor 52 device=self.token_embedding.weight.device,
58 ) 53 dtype=self.token_embedding.weight.dtype,
54 ))
59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 55 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
60 56
61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 57 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 with torch.no_grad(): 70 with torch.no_grad():
75 initializer = self.get_embed(initializer) 71 initializer = self.get_embed(initializer)
76 72
73 initializer = initializer.to(
74 device=self.token_embedding.weight.device,
75 dtype=self.token_embedding.weight.dtype,
76 )
77
77 token_ids = torch.tensor(token_ids, dtype=torch.long) 78 token_ids = torch.tensor(token_ids, dtype=torch.long)
78 79
79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 80 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
80 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) 81 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
81 self.temp_token_embedding.weight.data[mask] = initializer.to( 82
82 device=self.temp_token_embedding.weight.device, 83 for i, id in enumerate(mask):
83 dtype=self.temp_token_embedding.weight.dtype, 84 self.temp_token_embedding[id] = initializer[i]
84 )
85 85
86 def load_embed(self, input_ids: list[int], filename: Path): 86 def load_embed(self, input_ids: list[int], filename: Path):
87 with safe_open(filename, framework="pt", device="cpu") as file: 87 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
91 save_file({"embed": self.get_embed(input_ids)}, filename) 91 save_file({"embed": self.get_embed(input_ids)}, filename)
92 92
93 def persist(self): 93 def persist(self):
94 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 94 for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
95 self.token_embedding.weight.data[id] = emb
95 self.num_permanent_embeddings = self.token_embedding.num_embeddings 96 self.num_permanent_embeddings = self.token_embedding.num_embeddings
96 self.init_temp_embeddings() 97 self.init_temp_embeddings()
97 98
@@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
110 all_temp_token_ids = all_temp_token_ids.unsqueeze(0) 111 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
111 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() 112 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
112 113
113 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) 114 if len(temp_token_ids):
115 embeds_override = torch.stack([
116 self.temp_token_embedding[id]
117 for id in temp_token_ids
118 ])
119 embeds[embeds_mask] = embeds_override
114 120
115 return embeds 121 return embeds
116 122