Diffstat (limited to 'models')
 models/clip/embeddings.py |  76
 models/lora.py            | 131
 models/sparse.py          |  66
 3 files changed, 157 insertions(+), 116 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.sparse import PseudoSparseEmbedding
-
-
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
-    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
-
-    if old_num_embeddings == new_num_embeddings:
-        return old_embedding
-
-    n = min(old_num_embeddings, new_num_embeddings)
-
-    new_embedding = nn.Embedding(
-        new_num_embeddings,
-        old_embedding_dim,
-        device=old_embedding.weight.device,
-        dtype=old_embedding.weight.dtype
-    )
-    if initializer_factor is not None:
-        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-    else:
-        nn.init.zeros_(new_embedding.weight.data)
-    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
-    return new_embedding
+from models.lora import LoraEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
         super().__init__(config)
 
-        self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-
-        self.token_override_embedding = PseudoSparseEmbedding(
-            self.token_embedding.embedding_dim,
-            dropout_p=dropout_p,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
-        )
+        self.token_embedding = LoraEmbedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            r,
+            lora_alpha,
+            lora_dropout,
+        )
 
+        self.token_embedding.weight = embeddings.token_embedding.weight
+
     def resize(self, size: int):
-        self.token_override_embedding.resize(size)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
 
     def add_embed(
         self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(
-            self.token_embedding.num_embeddings,
-            device=self.token_override_embedding.mapping.device
-        )
-        embs, mask = self.token_override_embedding(input_ids)
-        if embs is not None:
-            input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] = embs
-            self.token_override_embedding.unset(input_ids)
+        self.token_embedding.eval()
+        self.token_embedding.merged = False
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        embs = self.token_embedding(input_ids)
-        embs_override, mask = self.token_override_embedding(input_ids)
-        if embs_override is not None:
-            embs[mask] = embs_override
-
-        return embs
+        return self.token_embedding(input_ids)
 
     def forward(
         self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
+def patch_managed_embeddings(
+    text_encoder: CLIPTextModel,
+    r: int = 8,
+    lora_alpha: int = 8,
+    lora_dropout: float = 0.0
+) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(
+        text_encoder.config,
+        text_encoder.text_model.embeddings,
+        r,
+        lora_alpha,
+        lora_dropout
+    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
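
For orientation, a minimal sketch of how the reworked entry point might be used; the checkpoint name, the extra vocabulary slot, and the input ids are illustrative assumptions, not part of this commit:

    import torch
    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Assumed checkpoint, for illustration only.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder, r=8, lora_alpha=8, lora_dropout=0.0)

    # Grow the vocabulary for a newly added placeholder token.
    embeddings.resize(embeddings.token_embedding.num_embeddings + 1)

    # Token embeddings now flow through the LoRA-backed token_embedding.
    input_ids = torch.tensor([[49406, 320, 49407]])  # illustrative ids
    hidden = text_encoder(input_ids=input_ids).last_hidden_state

    # After training, bake the LoRA delta into the weight matrix.
    embeddings.persist()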
diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
@@ -0,0 +1,131 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LoraLayer:
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout_p = lora_dropout
+
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = nn.Identity()
+
+        self.merged = False
+        self.merge_weights = merge_weights
+
+
+class LoraEmbedding(nn.Embedding, LoraLayer):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+        LoraLayer.__init__(
+            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
+        )
+
+        # Maps each vocabulary id to its LoRA column; -1 means "no LoRA delta".
+        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
+        self.trainable_ids -= 1
+
+        if r > 0:
+            # lora_A starts with zero columns and grows via mark_trainable().
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.scaling = self.lora_alpha / self.r
+            self.weight.requires_grad = False
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = LoraEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.r,
+            self.lora_alpha,
+            self.lora_dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        if self.r > 0:
+            # Share the LoRA factors with the resized copy.
+            new_emb.lora_A = self.lora_A
+            new_emb.lora_B = self.lora_B
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.LongTensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1].unique()
+
+        if new_ids.shape[0] == 0:
+            return
+
+        # Assign the next free LoRA columns to the newly trainable ids.
+        n = self.lora_A.shape[1]
+        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0], device=self.trainable_ids.device)
+
+        # Grow lora_A by one column per new id, keeping the existing columns.
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n + new_ids.shape[0])))
+        lora_A.data[:, :n] = self.lora_A.data
+        self.lora_A = lora_A
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        if mode and self.merge_weights and self.merged:
+            if self.r > 0:
+                # Un-merge: subtract the LoRA delta from the affected rows.
+                mask = self.trainable_ids != -1
+                delta = (self.lora_B @ self.lora_A).T[self.trainable_ids[mask]]
+                self.weight.data[mask] -= delta * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        if self.merge_weights and not self.merged:
+            if self.r > 0:
+                # Merge: bake the LoRA delta into the affected rows.
+                mask = self.trainable_ids != -1
+                delta = (self.lora_B @ self.lora_A).T[self.trainable_ids[mask]]
+                self.weight.data[mask] += delta * self.scaling
+            self.merged = True
+
+    def forward(self, input_ids: torch.Tensor):
+        result = nn.Embedding.forward(self, input_ids)
+
+        if self.r > 0 and not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = trainable_ids != -1
+            trainable_ids = trainable_ids[mask]
+
+            # Rank-r correction, applied only to the trainable ids.
+            after_A = F.embedding(
+                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
+                self.norm_type, self.scale_grad_by_freq, self.sparse
+            )
+            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+
+        return result
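
To make the mechanics above concrete, a standalone sketch of the LoraEmbedding lifecycle; the sizes and ids are toy values for illustration:

    import torch

    from models.lora import LoraEmbedding

    emb = LoraEmbedding(num_embeddings=10, embedding_dim=4, r=2, lora_alpha=2)

    # Only ids that were marked trainable receive a low-rank delta.
    emb.mark_trainable(torch.tensor([3, 7]))

    out = emb(torch.tensor([[1, 3, 7]]))
    print(out.shape)  # torch.Size([1, 3, 4])

    # eval() merges the delta into weight; train() removes it again.
    emb.eval()
    assert emb.merged
    emb.train()
    assert not emb.merged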
diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-
-class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.dtype = dtype
-        self.params = nn.ParameterList()
-
-        if dropout_p > 0.0:
-            self.dropout = nn.Dropout(p=dropout_p)
-        else:
-            self.dropout = nn.Identity()
-
-        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
-
-    def forward(self, input_ids: torch.LongTensor):
-        input_ids = input_ids.to(self.mapping.device)
-        ids = self.mapping[input_ids]
-        mask = ~(ids == -1)
-
-        if torch.all(~mask):
-            embs = None
-        else:
-            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
-
-        return embs, mask
-
-    def resize(self, new_num_embeddings: int):
-        old_num_embeddings = self.mapping.shape[0]
-        n = min(old_num_embeddings, new_num_embeddings)
-
-        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
-        new_mapping[:n] = self.mapping[:n]
-
-        self.mapping = new_mapping
-
-    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
-        if len(input_ids.shape) != 0:
-            if tensor is not None:
-                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
-            else:
-                return [self.set(id) for id in input_ids]
-
-        if tensor is None:
-            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
-
-        if tensor.shape[-1] != self.embedding_dim:
-            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
-
-        id = self.mapping[input_ids]
-
-        if id == -1:
-            id = len(self.params)
-            self.mapping[input_ids] = id
-            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
-
-        self.params[id] = tensor
-
-    def unset(self, input_ids: torch.LongTensor):
-        self.mapping[input_ids] = -1
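
For contrast, a rough sketch of the storage trade-off this commit makes; the shapes are illustrative, not taken from the code:

    import torch

    embedding_dim, r, n_trainable = 768, 8, 3

    # models/sparse.py (deleted): one full override vector per token,
    # i.e. n_trainable * embedding_dim parameters.
    override_params = [torch.zeros(embedding_dim) for _ in range(n_trainable)]

    # models/lora.py (new): a shared rank-r factorization,
    # i.e. r * (embedding_dim + n_trainable) parameters.
    lora_A = torch.zeros(r, n_trainable)
    lora_B = torch.zeros(embedding_dim, r)
    delta = (lora_B @ lora_A).T  # per-token correction, shape (n_trainable, embedding_dim)
    print(delta.shape)  # torch.Size([3, 768])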