Diffstat (limited to 'models/clip')

 -rw-r--r--  models/clip/embeddings.py  41
 1 file changed, 9 insertions(+), 32 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
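
For orientation, a minimal usage sketch of the updated entry point follows. Only the patch_managed_embeddings signature is taken from the new side of the diff above; the checkpoint name and hyperparameter values are illustrative assumptions, not part of this commit.

    # Hypothetical usage sketch; the checkpoint name is an example only.
    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # Replaces text_encoder.text_model.embeddings with the managed wrapper;
    # alpha and dropout are forwarded to SparseEmbedding, as in the new __init__.
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)
    assert text_encoder.text_model.embeddings is embeddings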
