Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 76
1 file changed, 26 insertions(+), 50 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.sparse import PseudoSparseEmbedding
-
-
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
-    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
-
-    if old_num_embeddings == new_num_embeddings:
-        return old_embedding
-
-    n = min(old_num_embeddings, new_num_embeddings)
-
-    new_embedding = nn.Embedding(
-        new_num_embeddings,
-        old_embedding_dim,
-        device=old_embedding.weight.device,
-        dtype=old_embedding.weight.dtype
-    )
-    if initializer_factor is not None:
-        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-    else:
-        nn.init.zeros_(new_embedding.weight.data)
-    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
-    return new_embedding
+from models.lora import LoraEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
         super().__init__(config)
 
-        self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-
-        self.token_override_embedding = PseudoSparseEmbedding(
-            self.token_embedding.embedding_dim,
-            dropout_p=dropout_p,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
-        )
+        self.token_embedding = LoraEmbedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            r,
+            lora_alpha,
+            lora_dropout,
+        )
+
+        self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
-        self.token_override_embedding.resize(size)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
 
     def add_embed(
         self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(
-            self.token_embedding.num_embeddings,
-            device=self.token_override_embedding.mapping.device
-        )
-        embs, mask = self.token_override_embedding(input_ids)
-        if embs is not None:
-            input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] = embs
-            self.token_override_embedding.unset(input_ids)
+        self.token_embedding.eval()
+        self.token_embedding.merged = False
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        embs = self.token_embedding(input_ids)
-        embs_override, mask = self.token_override_embedding(input_ids)
-        if embs_override is not None:
-            embs[mask] = embs_override
-
-        return embs
+        return self.token_embedding(input_ids)
 
     def forward(
         self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
+def patch_managed_embeddings(
+    text_encoder: CLIPTextModel,
+    r: int = 8,
+    lora_alpha: int = 8,
+    lora_dropout: float = 0.0
+) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(
+        text_encoder.config,
+        text_encoder.text_model.embeddings,
+        r,
+        lora_alpha,
+        lora_dropout
+    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
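
Note: the new code depends on models.lora.LoraEmbedding, which is introduced outside this file and is not shown in the diff. The sketch below is only an illustration of the interface the call sites above assume: the (num_embeddings, embedding_dim, r, lora_alpha, lora_dropout) constructor, a weight attribute compatible with nn.Embedding, new_resized() used by resize(), and the merged flag plus merge-on-eval behaviour that persist() relies on. The low-rank internals follow the usual loralib-style embedding adapter and are assumptions, not the repository's actual implementation.

# Hypothetical sketch of the LoraEmbedding interface assumed by the diff above;
# the real class lives in models/lora.py, which is not part of this change.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 8,
        lora_alpha: int = 8,
        lora_dropout: float = 0.0,
    ):
        super().__init__(num_embeddings, embedding_dim)

        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_p = lora_dropout
        self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.scaling = lora_alpha / r
        self.merged = False

        # Low-rank update; lora_A starts at zero so the wrapped CLIP weights are
        # reproduced exactly before any training.
        self.lora_A = nn.Parameter(torch.zeros(r, num_embeddings))
        self.lora_B = nn.Parameter(torch.randn(embedding_dim, r))

    def get_weights(self) -> torch.Tensor:
        # Dense view of the low-rank delta, shaped like `weight`.
        return (self.lora_B @ self.lora_A).T * self.scaling

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> "LoraEmbedding":
        # Used by ManagedCLIPTextEmbeddings.resize(): grow the vocabulary while
        # keeping the existing rows (this replaces the old resize_embedding helper).
        n = min(self.num_embeddings, new_num_embeddings)
        new_emb = LoraEmbedding(new_num_embeddings, self.embedding_dim, self.r, self.lora_alpha, self.dropout_p)
        new_emb = new_emb.to(device=self.weight.device, dtype=self.weight.dtype)

        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        new_emb.lora_A.data[:, :n] = self.lora_A.data[:, :n]
        new_emb.lora_B.data = self.lora_B.data.clone()
        return new_emb

    def train(self, mode: bool = True):
        # loralib-style merging: eval() folds the delta into `weight`, train()
        # unfolds it again. persist() above calls eval() and then clears
        # `merged`, which keeps the folded weights as the new base.
        if mode and self.merged:
            self.weight.data -= self.get_weights()
            self.merged = False
        elif not mode and not self.merged:
            self.weight.data += self.get_weights()
            self.merged = True
        return super().train(mode)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        emb = super().forward(input_ids)
        if not self.merged:
            # Add the low-rank delta only for the rows that were looked up.
            after_a = F.embedding(input_ids, self.lora_A.T)
            emb = emb + self.lora_dropout(after_a @ self.lora_B.T) * self.scaling
        return emb

Under this assumed interface, patch_managed_embeddings(text_encoder, r, lora_alpha, lora_dropout) swaps text_encoder.text_model.embeddings for the managed module and returns it; a training script can then add tokens via resize()/add_embed()/load_embed() and call persist() at the end, which folds the learned low-rank offsets into the stored token embeddings.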