summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-03-27 10:16:10 +0200
committer Volpeon <git@volpeon.ink> 2023-03-27 10:16:10 +0200
commit a412196d1a3b616655de52fb12e0d8528e1f1af0 (patch)
tree a97dec23f16231b5ca8c1092a009c983b6971880
parent Sparse TI embeddings without sparse tensors (diff)
download textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.tar.gz
textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.tar.bz2
textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.zip
Revert to regular embeddings
-rw-r--r-- models/clip/embeddings.py 34
1 files changed, 15 insertions, 19 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 95904cf..2b315c4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -42,16 +42,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
42 self.init_temp_embeddings() 42 self.init_temp_embeddings()
43 43
44 def init_temp_embeddings(self): 44 def init_temp_embeddings(self):
45 self.temp_token_embedding = nn.ParameterList() 45 self.temp_token_embedding = nn.Embedding(
46 0,
47 self.token_embedding.embedding_dim,
48 device=self.token_embedding.weight.device,
49 dtype=self.token_embedding.weight.dtype
50 )
46 self.temp_token_ids = torch.tensor([], dtype=torch.long) 51 self.temp_token_ids = torch.tensor([], dtype=torch.long)
47 52
48 def resize(self, size: int): 53 def resize(self, size: int):
49 for _ in range(len(self.temp_token_embedding), size): 54 self.temp_token_embedding = resize_embedding(
50 self.temp_token_embedding.append(torch.zeros( 55 self.temp_token_embedding,
51 self.token_embedding.embedding_dim, 56 size - self.num_permanent_embeddings,
52 device=self.token_embedding.weight.device, 57 self.initializer_factor
53 dtype=self.token_embedding.weight.dtype, 58 )
54 ))
55 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
56 60
57 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -78,10 +82,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
78 token_ids = torch.tensor(token_ids, dtype=torch.long) 82 token_ids = torch.tensor(token_ids, dtype=torch.long)
79 83
80 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 84 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
81 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) 85 mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1)
82 86 self.temp_token_embedding.weight.data[mask] = initializer
83 for i, id in enumerate(mask):
84 self.temp_token_embedding[id] = initializer[i]
85 87
86 def load_embed(self, input_ids: list[int], filename: Path): 88 def load_embed(self, input_ids: list[int], filename: Path):
87 with safe_open(filename, framework="pt", device="cpu") as file: 89 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -91,8 +93,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
91 save_file({"embed": self.get_embed(input_ids)}, filename) 93 save_file({"embed": self.get_embed(input_ids)}, filename)
92 94
93 def persist(self): 95 def persist(self):
94 for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): 96 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
95 self.token_embedding.weight.data[id] = emb
96 self.num_permanent_embeddings = self.token_embedding.num_embeddings 97 self.num_permanent_embeddings = self.token_embedding.num_embeddings
97 self.init_temp_embeddings() 98 self.init_temp_embeddings()
98 99
@@ -111,12 +112,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
111 all_temp_token_ids = all_temp_token_ids.unsqueeze(0) 112 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
112 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() 113 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
113 114
114 if len(temp_token_ids): 115 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
115 embeds_override = torch.stack([
116 self.temp_token_embedding[id]
117 for id in temp_token_ids
118 ])
119 embeds[embeds_mask] = embeds_override
120 116
121 return embeds 117 return embeds
122 118