summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-27 13:19:05 +0200
committerVolpeon <git@volpeon.ink>2023-03-27 13:19:05 +0200
commitc40170386fd055f715db90886f0ac0da5c575bd9 (patch)
tree063d4d89d179750a241e8e652d77ea8586fd2ac7 /models
parentFix TI (diff)
downloadtextual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.gz
textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.bz2
textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.zip
Fix TI
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py34
1 files changed, 25 insertions, 9 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2d60c28..e8cc865 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
38 self.token_embedding = embeddings.token_embedding 38 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 39 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 40 self.initializer_factor = config.initializer_factor
41 self.init_temp_embeddings()
41 42
43 def init_temp_embeddings(self):
42 self.temp_token_embedding = nn.Embedding( 44 self.temp_token_embedding = nn.Embedding(
43 self.token_embedding.num_embeddings, 45 0,
44 self.token_embedding.embedding_dim, 46 self.token_embedding.embedding_dim,
45 device=self.token_embedding.weight.device, 47 device=self.token_embedding.weight.device,
46 dtype=self.token_embedding.weight.dtype 48 dtype=self.token_embedding.weight.dtype
47 ) 49 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
49 self.temp_token_ids = torch.tensor([], dtype=torch.long) 50 self.temp_token_ids = torch.tensor([], dtype=torch.long)
50 51
51 def resize(self, size: int): 52 def resize(self, size: int):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 54
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,9 +74,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 ) 74 )
75 75
76 token_ids = torch.tensor(token_ids, dtype=torch.long) 76 token_ids = torch.tensor(token_ids, dtype=torch.long)
77
78 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 77 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
79 self.temp_token_embedding.weight.data[token_ids] = initializer 78
79 self.temp_token_embedding = resize_embedding(
80 self.temp_token_embedding,
81 self.temp_token_ids.shape[0],
82 self.initializer_factor
83 )
84
85 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
86 self.temp_token_embedding.weight.data[mask] = initializer
87 self.token_embedding.weight.data[token_ids] = initializer
80 88
81 def load_embed(self, input_ids: list[int], filename: Path): 89 def load_embed(self, input_ids: list[int], filename: Path):
82 with safe_open(filename, framework="pt", device="cpu") as file: 90 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -86,17 +94,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
86 save_file({"embed": self.get_embed(input_ids)}, filename) 94 save_file({"embed": self.get_embed(input_ids)}, filename)
87 95
88 def persist(self): 96 def persist(self):
89 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 97 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:]
90 self.temp_token_ids = torch.tensor([], dtype=torch.long) 98 self.init_temp_embeddings()
91 99
92 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 100 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
93 if isinstance(input_ids, list): 101 if isinstance(input_ids, list):
94 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 102 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
95 103
104 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
105
96 embeds = self.token_embedding(input_ids) 106 embeds = self.token_embedding(input_ids)
97 107
98 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 108 embeds_mask = torch.isin(input_ids, all_temp_token_ids)
99 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 109 temp_token_ids = input_ids[embeds_mask]
110
111 temp_token_ids = temp_token_ids.unsqueeze(1)
112 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
113 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
114
115 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
100 116
101 return embeds 117 return embeds
102 118