From b2db2b6a7c147cdc2901ece92f3918e5b3c47114 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 23:35:11 +0100 Subject: Fix --- models/clip/embeddings.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index cab1515..f90e7c2 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -37,6 +37,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding + self.initializer_factor = config.initializer_factor self.temp_token_embedding = nn.Embedding( self.token_embedding.num_embeddings, @@ -44,12 +45,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype ) - self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02) + self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor) - self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor) + self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) + self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): if isinstance(token_ids, int): @@ -63,14 +64,15 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = (initializer * len(token_ids))[:len(token_ids)] with torch.no_grad(): - initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype) + initializer = self.get_embed(initializer) token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) if initializer is not None: - self.temp_token_embedding.weight.data[token_ids] = initializer + self.temp_token_embedding.weight.data[token_ids] = initializer.to( + dtype=self.temp_token_embedding.weight.dtype) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: -- cgit v1.2.3-70-g09d2