From 3f9c6ed8b0c169d79213784463ffab962ec49419 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 15:32:36 +0200 Subject: Fix --- models/clip/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8aaea8f..2b23bd3 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -62,7 +62,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.token_embedding.mark_trainable(token_ids) - self.token_embedding.weight[token_ids].data = initializer + self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: -- cgit v1.2.3-54-g00ecf