diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index f90e7c2..9c3a56b 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 56 | if isinstance(token_ids, int): | 56 | if isinstance(token_ids, int): |
| 57 | token_ids = [token_ids] | 57 | token_ids = [token_ids] |
| 58 | 58 | ||
| 59 | if initializer is not None: | 59 | if initializer is None: |
| 60 | if isinstance(initializer, int): | 60 | initializer = token_ids |
| 61 | initializer = [initializer] | ||
| 62 | 61 | ||
| 63 | if isinstance(initializer, list): | 62 | if isinstance(initializer, int): |
| 64 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 63 | initializer = [initializer] |
| 65 | 64 | ||
| 66 | with torch.no_grad(): | 65 | if isinstance(initializer, list): |
| 67 | initializer = self.get_embed(initializer) | 66 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
| 67 | |||
| 68 | with torch.no_grad(): | ||
| 69 | initializer = self.get_embed(initializer) | ||
| 68 | 70 | ||
| 69 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 71 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 70 | 72 | ||
| 71 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 72 | 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | |
| 73 | if initializer is not None: | 75 | dtype=self.temp_token_embedding.weight.dtype) |
| 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | ||
| 75 | dtype=self.temp_token_embedding.weight.dtype) | ||
| 76 | 76 | ||
| 77 | def load_embed(self, input_ids: list[int], filename: Path): | 77 | def load_embed(self, input_ids: list[int], filename: Path): |
| 78 | with safe_open(filename, framework="pt", device="cpu") as file: | 78 | with safe_open(filename, framework="pt", device="cpu") as file: |
