summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py30
-rw-r--r--training/strategy/ti.py3
2 files changed, 24 insertions, 9 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6be6e9f..8d01867 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,24 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
38 self.token_embedding = embeddings.token_embedding 38 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 39 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 40 self.initializer_factor = config.initializer_factor
41 self.num_permanent_embeddings = self.token_embedding.num_embeddings
42 self.init_temp_embeddings()
41 43
44 def init_temp_embeddings(self):
42 self.temp_token_embedding = nn.Embedding( 45 self.temp_token_embedding = nn.Embedding(
43 self.token_embedding.num_embeddings, 46 0,
44 self.token_embedding.embedding_dim, 47 self.token_embedding.embedding_dim,
45 device=self.token_embedding.weight.device, 48 device=self.token_embedding.weight.device,
46 dtype=self.token_embedding.weight.dtype 49 dtype=self.token_embedding.weight.dtype
47 ) 50 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
49 self.temp_token_ids = torch.tensor([], dtype=torch.long) 51 self.temp_token_ids = torch.tensor([], dtype=torch.long)
50 52
51 def resize(self, size: int): 53 def resize(self, size: int):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) 54 self.temp_token_embedding = resize_embedding(
55 self.temp_token_embedding,
56 size - self.num_permanent_embeddings,
57 self.initializer_factor
58 )
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 60
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -71,7 +77,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
71 token_ids = torch.tensor(token_ids, dtype=torch.long) 77 token_ids = torch.tensor(token_ids, dtype=torch.long)
72 78
73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
74 self.temp_token_embedding.weight.data[token_ids] = initializer.to( 80 mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1)
81 self.temp_token_embedding.weight.data[mask] = initializer.to(
75 device=self.temp_token_embedding.weight.device, 82 device=self.temp_token_embedding.weight.device,
76 dtype=self.temp_token_embedding.weight.dtype, 83 dtype=self.temp_token_embedding.weight.dtype,
77 ) 84 )
@@ -85,16 +92,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
85 92
86 def persist(self): 93 def persist(self):
87 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 94 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
88 self.temp_token_ids = torch.tensor([], dtype=torch.long) 95 self.num_permanent_embeddings = self.token_embedding.num_embeddings
96 self.init_temp_embeddings()
89 97
90 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 98 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
91 if isinstance(input_ids, list): 99 if isinstance(input_ids, list):
92 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 100 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
93 101
102 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
103
94 embeds = self.token_embedding(input_ids) 104 embeds = self.token_embedding(input_ids)
95 105
96 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 106 embeds_mask = torch.isin(input_ids, all_temp_token_ids)
97 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 107 temp_token_ids = input_ids[embeds_mask]
108
109 temp_token_ids = temp_token_ids.unsqueeze(1)
110 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
111 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
112
113 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
98 114
99 return embeds 115 return embeds
100 116
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 10bc6d7..b9a5547 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -122,8 +122,7 @@ def textual_inversion_strategy_callbacks(
122 if lambda_ != 0: 122 if lambda_ != 0:
123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight 123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
124 124
125 mask = torch.zeros(w.shape[0], dtype=torch.bool) 125 mask = torch.ones(w.shape[0], dtype=torch.bool)
126 mask[text_encoder.text_model.embeddings.temp_token_ids] = True
127 mask[zero_ids] = False 126 mask[zero_ids] = False
128 127
129 norm = w[mask, :].norm(dim=-1, keepdim=True) 128 norm = w[mask, :].norm(dim=-1, keepdim=True)