path: root/models
author    Volpeon <git@volpeon.ink>  2023-03-27 11:58:47 +0200
committer Volpeon <git@volpeon.ink>  2023-03-27 11:58:47 +0200
commit    1c63552a20f34bccd461ac0dfa46405f853cbc7c (patch)
tree      df26b48ff4c2ef79349b0a4025cdde05b0ed8518 /models
parent    Fix TI (diff)
Fix TI
Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py | 34
1 file changed, 9 insertions, 25 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b315c4..2d60c28 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,24 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.num_permanent_embeddings = self.token_embedding.num_embeddings
-        self.init_temp_embeddings()
 
-    def init_temp_embeddings(self):
         self.temp_token_embedding = nn.Embedding(
-            0,
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(
-            self.temp_token_embedding,
-            size - self.num_permanent_embeddings,
-            self.initializer_factor
-        )
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -75,15 +69,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 initializer = self.get_embed(initializer)
 
         initializer = initializer.to(
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
+            device=self.temp_token_embedding.weight.device,
+            dtype=self.temp_token_embedding.weight.dtype,
         )
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1)
-        self.temp_token_embedding.weight.data[mask] = initializer
+        self.temp_token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -94,25 +87,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def persist(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.num_permanent_embeddings = self.token_embedding.num_embeddings
-        self.init_temp_embeddings()
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
         embeds = self.token_embedding(input_ids)
 
-        embeds_mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[embeds_mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
 
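
Note on the scheme this diff moves to: the temporary embedding table is now a full-size, detached copy of the permanent one and is indexed with the same global token ids, so get_embed can overlay trained rows with a single torch.isin mask and persist reduces to a row copy. The standalone sketch below reproduces that lookup logic with a toy table; the sizes, tensors, and variable names are illustrative only and are not part of the repository.

import torch
import torch.nn as nn

num_embeddings, dim = 10, 4

# Permanent table plus a same-sized trainable copy, mirroring the new __init__.
token_embedding = nn.Embedding(num_embeddings, dim)
temp_token_embedding = nn.Embedding(num_embeddings, dim)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()

# Pretend tokens 7 and 8 were registered via add_embed and then trained.
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)
temp_token_embedding.weight.data[temp_token_ids] += 1.0

input_ids = torch.tensor([1, 7, 3, 8])

# get_embed logic: look everything up in the permanent table, then overwrite
# the positions holding temporary tokens with rows from the temp table.
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids.to(input_ids.device))
embeds[mask] = temp_token_embedding(input_ids)[mask]

# persist logic: copy the trained temp rows back into the permanent table
# and forget the temporary ids.
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]
temp_token_ids = torch.tensor([], dtype=torch.long)

Keeping the temp table the same size as the permanent one costs a duplicate weight matrix, but it removes the id-remapping bookkeeping (the unsqueeze/nonzero block deleted in the last hunk) and the size - num_permanent_embeddings offset previously needed in resize.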