diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index cab1515..f90e7c2 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -37,6 +37,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
37 | 37 | ||
38 | self.token_embedding = embeddings.token_embedding | 38 | self.token_embedding = embeddings.token_embedding |
39 | self.position_embedding = embeddings.position_embedding | 39 | self.position_embedding = embeddings.position_embedding |
40 | self.initializer_factor = config.initializer_factor | ||
40 | 41 | ||
41 | self.temp_token_embedding = nn.Embedding( | 42 | self.temp_token_embedding = nn.Embedding( |
42 | self.token_embedding.num_embeddings, | 43 | self.token_embedding.num_embeddings, |
@@ -44,12 +45,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
44 | device=self.token_embedding.weight.device, | 45 | device=self.token_embedding.weight.device, |
45 | dtype=self.token_embedding.weight.dtype | 46 | dtype=self.token_embedding.weight.dtype |
46 | ) | 47 | ) |
47 | self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02) | 48 | self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) |
48 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
49 | 50 | ||
50 | def resize(self, size: int): | 51 | def resize(self, size: int): |
51 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor) | 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) |
52 | self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
53 | 54 | ||
54 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
55 | if isinstance(token_ids, int): | 56 | if isinstance(token_ids, int): |
@@ -63,14 +64,15 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
63 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 64 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
64 | 65 | ||
65 | with torch.no_grad(): | 66 | with torch.no_grad(): |
66 | initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype) | 67 | initializer = self.get_embed(initializer) |
67 | 68 | ||
68 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 69 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
69 | 70 | ||
70 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 71 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
71 | 72 | ||
72 | if initializer is not None: | 73 | if initializer is not None: |
73 | self.temp_token_embedding.weight.data[token_ids] = initializer | 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( |
75 | dtype=self.temp_token_embedding.weight.dtype) | ||
74 | 76 | ||
75 | def load_embed(self, input_ids: list[int], filename: Path): | 77 | def load_embed(self, input_ids: list[int], filename: Path): |
76 | with safe_open(filename, framework="pt", device="cpu") as file: | 78 | with safe_open(filename, framework="pt", device="cpu") as file: |