From 5c115a212e40ff177c734351601f9babe29419ce Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 22:05:25 +0100 Subject: Added EMA to TI --- models/clip/embeddings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index fb639f1..384c795 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): def save_embed(self, input_ids: list[int], filename: Path): save_file({"embed": self.get_embed(input_ids)}, filename) - def make_permanent(self): + def persist(self): self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] self.temp_token_ids = torch.tensor([], dtype=torch.long) @@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) - mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds = self.token_embedding(input_ids) + + mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) embeds[mask] = self.temp_token_embedding(input_ids)[mask] return embeds -- cgit v1.2.3-54-g00ecf