diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 5 |
1 files changed, 0 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 384c795..9d8f770 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -53,8 +53,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 54 | 54 | ||
| 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
| 56 | init_ratio = 1.0 | ||
| 57 | |||
| 58 | if isinstance(token_ids, int): | 56 | if isinstance(token_ids, int): |
| 59 | token_ids = [token_ids] | 57 | token_ids = [token_ids] |
| 60 | 58 | ||
| @@ -65,7 +63,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 65 | initializer = [initializer] | 63 | initializer = [initializer] |
| 66 | 64 | ||
| 67 | if isinstance(initializer, list): | 65 | if isinstance(initializer, list): |
| 68 | init_ratio = len(initializer) / len(token_ids) | ||
| 69 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 66 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
| 70 | 67 | ||
| 71 | with torch.no_grad(): | 68 | with torch.no_grad(): |
| @@ -79,8 +76,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 79 | dtype=self.temp_token_embedding.weight.dtype, | 76 | dtype=self.temp_token_embedding.weight.dtype, |
| 80 | ) | 77 | ) |
| 81 | 78 | ||
| 82 | return init_ratio | ||
| 83 | |||
| 84 | def load_embed(self, input_ids: list[int], filename: Path): | 79 | def load_embed(self, input_ids: list[int], filename: Path): |
| 85 | with safe_open(filename, framework="pt", device="cpu") as file: | 80 | with safe_open(filename, framework="pt", device="cpu") as file: |
| 86 | self.add_embed(input_ids, file.get_tensor("embed")) | 81 | self.add_embed(input_ids, file.get_tensor("embed")) |
