Diffstat (limited to 'models/clip')
 models/clip/embeddings.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index cab1515..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,6 +37,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):

         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
+        self.initializer_factor = config.initializer_factor

         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
@@ -44,12 +45,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02)
+        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)

     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor)
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
+        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)

     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -63,14 +64,15 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = (initializer * len(token_ids))[:len(token_ids)]

         with torch.no_grad():
-            initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype)
+            initializer = self.get_embed(initializer)

         token_ids = torch.tensor(token_ids, dtype=torch.long)

         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])

         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)

     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
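
Besides caching config.initializer_factor on the instance, the second and third hunks move the dtype cast out of the no_grad block: get_embed() now returns the initializer in its original dtype, and the cast to the embedding weight's dtype happens only at the point of assignment. Below is a minimal standalone sketch of that pattern; temp_token_embedding, initializer_factor and the 16x8 table are placeholders, not the repository's actual objects.

# Standalone sketch (assumed names, not this repository's API) of casting the
# initializer only when it is written into the embedding weight.
import torch
import torch.nn as nn

initializer_factor = 1.0  # stands in for the cached config.initializer_factor

# Placeholder table standing in for temp_token_embedding, converted to fp16.
temp_token_embedding = nn.Embedding(16, 8)
temp_token_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
temp_token_embedding = temp_token_embedding.to(torch.float16)

with torch.no_grad():
    # Stands in for get_embed(initializer); stays fp32 instead of being cast early.
    initializer = torch.randn(2, 8)

token_ids = torch.tensor([3, 5], dtype=torch.long)

if initializer is not None:
    # Cast only at assignment time, mirroring the patched add_embed().
    temp_token_embedding.weight.data[token_ids] = initializer.to(
        dtype=temp_token_embedding.weight.dtype)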
