summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py76
1 files changed, 26 insertions, 50 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
11from transformers.models.clip import CLIPTextConfig 11from transformers.models.clip import CLIPTextConfig
12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings 12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
13 13
14from models.sparse import PseudoSparseEmbedding 14from models.lora import LoraEmbedding
15
16
17def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
18 old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
19
20 if old_num_embeddings == new_num_embeddings:
21 return old_embedding
22
23 n = min(old_num_embeddings, new_num_embeddings)
24
25 new_embedding = nn.Embedding(
26 new_num_embeddings,
27 old_embedding_dim,
28 device=old_embedding.weight.device,
29 dtype=old_embedding.weight.dtype
30 )
31 if initializer_factor is not None:
32 new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
33 else:
34 nn.init.zeros_(new_embedding.weight.data)
35 new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
36 return new_embedding
37 15
38 16
39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 17class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): 18 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
41 super().__init__(config) 19 super().__init__(config)
42 20
43 self.token_embedding = embeddings.token_embedding
44 self.position_embedding = embeddings.position_embedding 21 self.position_embedding = embeddings.position_embedding
45 self.initializer_factor = config.initializer_factor 22 self.initializer_factor = config.initializer_factor
46 23 self.token_embedding = LoraEmbedding(
47 self.token_override_embedding = PseudoSparseEmbedding( 24 self.token_embedding.num_embeddings,
48 self.token_embedding.embedding_dim, 25 self.token_embedding.embedding_dim,
49 dropout_p=dropout_p, 26 r,
50 device=self.token_embedding.weight.device, 27 lora_alpha,
51 dtype=self.token_embedding.weight.dtype, 28 lora_dropout,
52 ) 29 )
53 30
31 self.token_embedding.weight = embeddings.token_embedding.weight
32
54 def resize(self, size: int): 33 def resize(self, size: int):
55 self.token_override_embedding.resize(size) 34 self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
56 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
57 35
58 def add_embed( 36 def add_embed(
59 self, 37 self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
87 token_ids = torch.tensor(token_ids, dtype=torch.long) 65 token_ids = torch.tensor(token_ids, dtype=torch.long)
88 66
89 self.token_embedding.weight.data[token_ids] = initializer 67 self.token_embedding.weight.data[token_ids] = initializer
90 self.token_override_embedding.set(token_ids, initializer)
91 68
92 def load_embed(self, input_ids: list[int], filename: Path): 69 def load_embed(self, input_ids: list[int], filename: Path):
93 with safe_open(filename, framework="pt", device="cpu") as file: 70 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
97 save_file({"embed": self.get_embed(input_ids)}, filename) 74 save_file({"embed": self.get_embed(input_ids)}, filename)
98 75
99 def persist(self): 76 def persist(self):
100 input_ids = torch.arange( 77 self.token_embedding.eval()
101 self.token_embedding.num_embeddings, 78 self.token_embedding.merged = False
102 device=self.token_override_embedding.mapping.device
103 )
104 embs, mask = self.token_override_embedding(input_ids)
105 if embs is not None:
106 input_ids = input_ids[mask]
107 self.token_embedding.weight.data[input_ids] = embs
108 self.token_override_embedding.unset(input_ids)
109 79
110 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 80 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
111 if isinstance(input_ids, list): 81 if isinstance(input_ids, list):
112 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 82 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
113 83
114 embs = self.token_embedding(input_ids) 84 return self.token_embedding(input_ids)
115 embs_override, mask = self.token_override_embedding(input_ids)
116 if embs_override is not None:
117 embs[mask] = embs_override
118
119 return embs
120 85
121 def forward( 86 def forward(
122 self, 87 self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
138 return embeddings 103 return embeddings
139 104
140 105
141def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: 106def patch_managed_embeddings(
142 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) 107 text_encoder: CLIPTextModel,
108 r: int = 8,
109 lora_alpha: int = 8,
110 lora_dropout: float = 0.0
111) -> ManagedCLIPTextEmbeddings:
112 text_embeddings = ManagedCLIPTextEmbeddings(
113 text_encoder.config,
114 text_encoder.text_model.embeddings,
115 r,
116 lora_alpha,
117 lora_dropout
118 )
143 text_encoder.text_model.embeddings = text_embeddings 119 text_encoder.text_model.embeddings = text_embeddings
144 return text_embeddings 120 return text_embeddings