Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/embeddings.py | 76
1 file changed, 26 insertions(+), 50 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.sparse import PseudoSparseEmbedding
-
-
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
-    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
-
-    if old_num_embeddings == new_num_embeddings:
-        return old_embedding
-
-    n = min(old_num_embeddings, new_num_embeddings)
-
-    new_embedding = nn.Embedding(
-        new_num_embeddings,
-        old_embedding_dim,
-        device=old_embedding.weight.device,
-        dtype=old_embedding.weight.dtype
-    )
-    if initializer_factor is not None:
-        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-    else:
-        nn.init.zeros_(new_embedding.weight.data)
-    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
-    return new_embedding
+from models.lora import LoraEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
         super().__init__(config)
 
-        self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-
-        self.token_override_embedding = PseudoSparseEmbedding(
-            self.token_embedding.embedding_dim,
-            dropout_p=dropout_p,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
-        )
+        self.token_embedding = LoraEmbedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            r,
+            lora_alpha,
+            lora_dropout,
+        )
 
+        self.token_embedding.weight = embeddings.token_embedding.weight
+
     def resize(self, size: int):
-        self.token_override_embedding.resize(size)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
 
     def add_embed(
         self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(
-            self.token_embedding.num_embeddings,
-            device=self.token_override_embedding.mapping.device
-        )
-        embs, mask = self.token_override_embedding(input_ids)
-        if embs is not None:
-            input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] = embs
-            self.token_override_embedding.unset(input_ids)
+        self.token_embedding.eval()
+        self.token_embedding.merged = False
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        embs = self.token_embedding(input_ids)
-        embs_override, mask = self.token_override_embedding(input_ids)
-        if embs_override is not None:
-            embs[mask] = embs_override
-
-        return embs
+        return self.token_embedding(input_ids)
 
     def forward(
         self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
+def patch_managed_embeddings(
+    text_encoder: CLIPTextModel,
+    r: int = 8,
+    lora_alpha: int = 8,
+    lora_dropout: float = 0.0
+) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(
+        text_encoder.config,
+        text_encoder.text_model.embeddings,
+        r,
+        lora_alpha,
+        lora_dropout
+    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
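The commit above replaces the PseudoSparseEmbedding override table with a LoraEmbedding that wraps the token embedding directly and exposes r, lora_alpha and lora_dropout as tuning knobs. A minimal usage sketch follows; it assumes only what the patched signatures in this diff show plus the standard transformers CLIPTextModel.from_pretrained API, and the checkpoint name is illustrative, not taken from the repository.

    # Sketch: patch a CLIP text encoder with the LoRA-backed managed embeddings.
    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # "openai/clip-vit-large-patch14" is an example checkpoint, not specified by this commit.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # r / lora_alpha / lora_dropout are the parameters introduced by this commit;
    # the defaults (8, 8, 0.0) mirror the new patch_managed_embeddings signature.
    embeddings = patch_managed_embeddings(text_encoder, r=8, lora_alpha=8, lora_dropout=0.0)

    # The returned ManagedCLIPTextEmbeddings now replaces text_encoder.text_model.embeddings,
    # so the vocabulary can be grown through it (resize delegates to LoraEmbedding.new_resized).
    embeddings.resize(text_encoder.config.vocab_size + 1)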
