From 68164329b97f5cd79a56372dc6cace4b038afce8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 22:08:21 +0100
Subject: Update

---
 models/clip/embeddings.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..9c3a56b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
-        if initializer is not None:
-            if isinstance(initializer, int):
-                initializer = [initializer]
+        if initializer is None:
+            initializer = token_ids
 
-            if isinstance(initializer, list):
-                initializer = (initializer * len(token_ids))[:len(token_ids)]
+        if isinstance(initializer, int):
+            initializer = [initializer]
 
-            with torch.no_grad():
-                initializer = self.get_embed(initializer)
+        if isinstance(initializer, list):
+            initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+        with torch.no_grad():
+            initializer = self.get_embed(initializer)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-                dtype=self.temp_token_embedding.weight.dtype)
+        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+            dtype=self.temp_token_embedding.weight.dtype)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
-- 
cgit v1.2.3-54-g00ecf
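
For reference, this is a minimal sketch of how the method reads after the
patch, assembled from the new side of the hunk above. The method name
`add_embed` and its signature are assumptions (the hunk header only shows
the enclosing class), while the body follows the post-patch state of the
diff; torch, safetensors, and the class attributes it touches come from the
surrounding file.

    def add_embed(self, token_ids, initializer=None):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # New in this patch: with no explicit initializer, fall back to
        # the token IDs themselves, so each new token starts from its own
        # current embedding.
        if initializer is None:
            initializer = token_ids

        if isinstance(initializer, int):
            initializer = [initializer]

        if isinstance(initializer, list):
            # Cycle the initializer list until it covers every new token.
            initializer = (initializer * len(token_ids))[:len(token_ids)]

        # Resolve initializer token IDs to embedding vectors, without
        # tracking gradients.
        with torch.no_grad():
            initializer = self.get_embed(initializer)

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
        # The copy is now unconditional: initializer is guaranteed to be
        # a tensor at this point.
        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
            dtype=self.temp_token_embedding.weight.dtype)

The net effect of defaulting `initializer` to `token_ids` is that the
trailing `if initializer is not None` guard becomes unnecessary: every
token registered in `temp_token_ids` now always gets a corresponding row
written into `temp_token_embedding`.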