Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py |  41
-rw-r--r--  models/lora.py            |  77
-rw-r--r--  models/sparse.py          | 110
3 files changed, 157 insertions(+), 71 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
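
The embeddings patch above only re-plumbs constructor arguments; the forward() override is dropped in favor of the base class implementation. A minimal usage sketch of the new entry point, assuming a standard Hugging Face CLIP checkpoint (the checkpoint name, token ids, and hyperparameter values here are illustrative, not part of this commit):

import torch
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Any CLIP text encoder works; this checkpoint name is only an example.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Replace the stock token embedding with the SparseEmbedding wrapper.
embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

# The pretrained weight tensor is shared, so outputs are unchanged until
# individual token ids are marked trainable.
input_ids = torch.tensor([[49406, 320, 49407]])
print(text_encoder(input_ids).last_hidden_state.shape)
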
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
             self.trainable_ids[:] = -1
             self.lora_A = nn.ParameterList()
             nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-                self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
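
With the r > 0 guards gone, LoraEmbedding always carries its lora_A/lora_B pair, and merging moves from eval() into train(False). A small sketch of the merge/unmerge cycle, assuming merge_weights defaults to True in LoraLayer and that the trainable-id registration method shown above is named mark_trainable as in models/sparse.py (the positional constructor arguments mirror the old call site in embeddings.py):

import torch

from models.lora import LoraEmbedding

# num_embeddings, embedding_dim, r, lora_alpha, lora_dropout
emb = LoraEmbedding(1000, 64, 8, 8, 0.0)
emb.mark_trainable(torch.tensor([3, 17]))  # allocate rank-r vectors for two ids

emb.train(False)  # folds lora_B(lora_A) * scaling into emb.weight
assert emb.merged
emb.train(True)   # subtracts the deltas again so training resumes unmerged
assert not emb.merged

Because get_weights now returns zeros whenever self.merged is set, forward() can add it unconditionally instead of re-checking r > 0 and the merge flag.
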
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..bd45696
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,110 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class SparseEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        alpha: int = 1,
+        dropout: float = 0.0,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
+
+        self.trainable = nn.ParameterList()
+        self.scaling = alpha
+        self.dropout_p = dropout
+        self.weight.requires_grad = False
+
+        if dropout > 0.:
+            self.dropout = nn.Dropout(p=dropout)
+        else:
+            self.dropout = nn.Identity()
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = SparseEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.scaling,
+            self.dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        for param in self.trainable:
+            new_emb.trainable.append(param)
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.LongTensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1]
+
+        if new_ids.shape[0] == 0:
+            return
+
+        n1 = len(self.trainable)
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.trainable.append(self.weight.new_zeros(self.embedding_dim))
+
+    def get_weights(self, input_ids: torch.Tensor):
+        original_shape = input_ids.shape
+
+        if len(input_ids.shape) != 1:
+            input_ids = input_ids.view(input_ids.shape[0] * input_ids.shape[1])
+
+        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
+
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        elems = [self.trainable[id] for id in trainable_ids[mask]]
+
+        if len(elems) != 0:
+            w = self.dropout(torch.stack(elems)) * self.scaling
+            weights[mask] = w.to(dtype=weights.dtype)
+
+        if len(original_shape) != 1:
+            weights = weights.view(original_shape[0], original_shape[1], -1)
+
+        return weights
+
+    def persist(self):
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.trainable = nn.ParameterList()
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, "trainable"):
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        self.trainable.train(mode)
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        self.trainable.eval()
+
+    def forward(self, input_ids: torch.LongTensor):
+        result = nn.Embedding.forward(self, input_ids)
+        result += self.get_weights(input_ids)
+        return result
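
models/sparse.py keeps the bookkeeping of the old LoRA embedding (the trainable_ids buffer, new_resized, persist) but drops the low-rank factorization: each marked token id owns a full embedding_dim delta vector, scaled by alpha and added to the frozen base row. A minimal sketch of the API above (vocabulary size, dimension, and ids are illustrative):

import torch

from models.sparse import SparseEmbedding

emb = SparseEmbedding(1000, 64, alpha=1, dropout=0.0)
emb.mark_trainable(torch.tensor([3, 17]))  # full-width deltas for two token ids

out = emb(torch.tensor([[3, 5, 17]]))      # frozen rows plus scaled deltas
print(out.shape)                           # torch.Size([1, 3, 64])

emb.persist()                              # fold the deltas into emb.weight
assert bool((emb.trainable_ids == -1).all())

Since weight.requires_grad is False, only the per-id delta vectors ever participate in training, and persist() bakes them into the base table once training is done.
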