diff options
author | Volpeon <git@volpeon.ink> | 2023-01-05 22:05:25 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-05 22:05:25 +0100 |
commit | 5c115a212e40ff177c734351601f9babe29419ce (patch) | |
tree | a66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /models/clip | |
parent | Fix LR finder (diff) | |
download | textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.gz textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.bz2 textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.zip |
Added EMA to TI
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index fb639f1..384c795 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
88 | def save_embed(self, input_ids: list[int], filename: Path): | 88 | def save_embed(self, input_ids: list[int], filename: Path): |
89 | save_file({"embed": self.get_embed(input_ids)}, filename) | 89 | save_file({"embed": self.get_embed(input_ids)}, filename) |
90 | 90 | ||
91 | def make_permanent(self): | 91 | def persist(self): |
92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
94 | 94 | ||
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
96 | if isinstance(input_ids, list): | 96 | if isinstance(input_ids, list): |
97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
98 | 98 | ||
99 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
100 | |||
101 | embeds = self.token_embedding(input_ids) | 99 | embeds = self.token_embedding(input_ids) |
100 | |||
101 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] |
103 | 103 | ||
104 | return embeds | 104 | return embeds |