summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-26 14:29:57 +0200
committerVolpeon <git@volpeon.ink>2023-03-26 14:29:57 +0200
commitb5e0ef7b8a4629c2d1885a96f0faf24fafba1467 (patch)
tree675749d04db22ffca4ca0eb74449c1242c582bc4 /models
parentImproved inverted tokens (diff)
downloadtextual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.tar.gz
textual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.tar.bz2
textual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.zip
Improved TI embeddings
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py30
1 file changed, 23 insertions, 7 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6be6e9f..8d01867 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,24 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
38 self.token_embedding = embeddings.token_embedding 38 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 39 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 40 self.initializer_factor = config.initializer_factor
41 self.num_permanent_embeddings = self.token_embedding.num_embeddings
42 self.init_temp_embeddings()
41 43
44 def init_temp_embeddings(self):
42 self.temp_token_embedding = nn.Embedding( 45 self.temp_token_embedding = nn.Embedding(
43 self.token_embedding.num_embeddings, 46 0,
44 self.token_embedding.embedding_dim, 47 self.token_embedding.embedding_dim,
45 device=self.token_embedding.weight.device, 48 device=self.token_embedding.weight.device,
46 dtype=self.token_embedding.weight.dtype 49 dtype=self.token_embedding.weight.dtype
47 ) 50 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
49 self.temp_token_ids = torch.tensor([], dtype=torch.long) 51 self.temp_token_ids = torch.tensor([], dtype=torch.long)
50 52
51 def resize(self, size: int): 53 def resize(self, size: int):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) 54 self.temp_token_embedding = resize_embedding(
55 self.temp_token_embedding,
56 size - self.num_permanent_embeddings,
57 self.initializer_factor
58 )
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 60
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -71,7 +77,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
71 token_ids = torch.tensor(token_ids, dtype=torch.long) 77 token_ids = torch.tensor(token_ids, dtype=torch.long)
72 78
73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
74 self.temp_token_embedding.weight.data[token_ids] = initializer.to( 80 mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1)
81 self.temp_token_embedding.weight.data[mask] = initializer.to(
75 device=self.temp_token_embedding.weight.device, 82 device=self.temp_token_embedding.weight.device,
76 dtype=self.temp_token_embedding.weight.dtype, 83 dtype=self.temp_token_embedding.weight.dtype,
77 ) 84 )
@@ -85,16 +92,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
85 92
86 def persist(self): 93 def persist(self):
87 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 94 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
88 self.temp_token_ids = torch.tensor([], dtype=torch.long) 95 self.num_permanent_embeddings = self.token_embedding.num_embeddings
96 self.init_temp_embeddings()
89 97
90 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 98 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
91 if isinstance(input_ids, list): 99 if isinstance(input_ids, list):
92 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 100 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
93 101
102 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
103
94 embeds = self.token_embedding(input_ids) 104 embeds = self.token_embedding(input_ids)
95 105
96 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 106 embeds_mask = torch.isin(input_ids, all_temp_token_ids)
97 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 107 temp_token_ids = input_ids[embeds_mask]
108
109 temp_token_ids = temp_token_ids.unsqueeze(1)
110 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
111 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
112
113 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
98 114
99 return embeds 115 return embeds
100 116