summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 17:33:00 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 17:33:00 +0200
commit86e908656bcd7585ec45cd930176800f759f146a (patch)
tree1169e9b1728e4c6fc8b70e46a37080ae0794ada8 /models/clip/embeddings.py
parentExperimental: TI via LoRA (diff)
downloadtextual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.gz
textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.bz2
textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.zip
Combined TI with embedding and LoRA
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 88e0cc0..c9c788c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
66 self.initializer_factor = config.initializer_factor 66 self.initializer_factor = config.initializer_factor
67 67
68 self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) 68 self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
69 self.temp_token_embedding = nn.Embedding(
70 self.token_embedding.num_embeddings,
71 self.token_embedding.embedding_dim,
72 device=self.token_embedding.weight.device,
73 dtype=self.token_embedding.weight.dtype
74 )
75 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
69 self.temp_token_ids = torch.tensor([], dtype=torch.long) 76 self.temp_token_ids = torch.tensor([], dtype=torch.long)
70 77
71 def reset_overlay(self): 78 def reset_overlay(self):
72 self.overlay.reset() 79 self.overlay.reset()
73 80
74 def resize(self, size: int): 81 def resize(self, size: int):
82 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
75 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 83 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
76 84
77 def add_embed( 85 def add_embed(
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
106 token_ids = torch.tensor(token_ids, dtype=torch.long) 114 token_ids = torch.tensor(token_ids, dtype=torch.long)
107 115
108 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 116 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
117 self.temp_token_embedding.weight.data[token_ids] = initializer
109 self.token_embedding.weight.data[token_ids] = initializer 118 self.token_embedding.weight.data[token_ids] = initializer
110 119
111 def load_embed(self, input_ids: list[int], filename: Path): 120 def load_embed(self, input_ids: list[int], filename: Path):
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
116 save_file({"embed": self.get_embed(input_ids)}, filename) 125 save_file({"embed": self.get_embed(input_ids)}, filename)
117 126
118 def persist(self): 127 def persist(self):
119 self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( 128 embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
120 self.token_embedding.weight.data[self.temp_token_ids] 129 self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
121 )
122 self.overlay.reset() 130 self.overlay.reset()
123 self.temp_token_ids = torch.tensor([], dtype=torch.long) 131 self.temp_token_ids = torch.tensor([], dtype=torch.long)
124 132
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
127 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 135 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
128 136
129 embeds = self.token_embedding(input_ids) 137 embeds = self.token_embedding(input_ids)
138
130 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 139 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
131 embeds[mask] += self.overlay(embeds[mask]) 140
141 temp_embeds = self.temp_token_embedding(input_ids[mask])
142 embeds[mask] = temp_embeds + self.overlay(temp_embeds)
132 143
133 return embeds 144 return embeds
134 145