summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 22:05:25 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 22:05:25 +0100
commit5c115a212e40ff177c734351601f9babe29419ce (patch)
treea66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /models
parentFix LR finder (diff)
downloadtextual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.gz
textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.bz2
textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.zip
Added EMA to TI
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index fb639f1..384c795 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
88 def save_embed(self, input_ids: list[int], filename: Path): 88 def save_embed(self, input_ids: list[int], filename: Path):
89 save_file({"embed": self.get_embed(input_ids)}, filename) 89 save_file({"embed": self.get_embed(input_ids)}, filename)
90 90
91 def make_permanent(self): 91 def persist(self):
92 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 92 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
93 self.temp_token_ids = torch.tensor([], dtype=torch.long) 93 self.temp_token_ids = torch.tensor([], dtype=torch.long)
94 94
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
96 if isinstance(input_ids, list): 96 if isinstance(input_ids, list):
97 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 97 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
98 98
99 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
100
101 embeds = self.token_embedding(input_ids) 99 embeds = self.token_embedding(input_ids)
100
101 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
102 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 102 embeds[mask] = self.temp_token_embedding(input_ids)[mask]
103 103
104 return embeds 104 return embeds