Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py |  3
-rw-r--r--  models/lora.py            | 59
2 files changed, 38 insertions, 24 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.eval()
-        self.token_embedding.merged = False
+        self.token_embedding.persist()
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
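Note on the hunk above: persist() no longer merges by calling eval() and then clearing the merged flag; it delegates to a dedicated persist() on the wrapped embedding. A minimal sketch of that contract, assuming token_embedding exposes a persist() that folds its LoRA delta into the base weight as in models/lora.py below. The stub classes are illustrative, not the repository's ManagedCLIPTextEmbeddings or LoraEmbedding:

# Illustrative stubs only, not the repository's classes.
import torch.nn as nn

class _LoraEmbeddingStub(nn.Embedding):
    def persist(self):
        # In the real LoraEmbedding this adds the LoRA delta to .weight
        # and clears trainable_ids / lora_A.
        print("merged LoRA delta into base weight")

class _ManagedEmbeddingsSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = _LoraEmbeddingStub(10, 4)

    def persist(self):
        # Before: self.token_embedding.eval(); self.token_embedding.merged = False
        # (merge as a side effect of eval, then pretend it never happened).
        # After: one explicit, irreversible call on the LoRA layer.
        self.token_embedding.persist()

_ManagedEmbeddingsSketch().persist()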
diff --git a/models/lora.py b/models/lora.py
index 89c4b2e..b7fa58f 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -46,8 +46,8 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         self.trainable_ids -= 1
 
         if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
-            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.lora_A = nn.ParameterList()
+            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
             self.scaling = self.lora_alpha / self.r
             self.weight.requires_grad = False
 
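For reference, the representational change in the hunk above: lora_A goes from a dense (r, n_trainable) matrix that had to be re-allocated to grow, to a ParameterList holding one rank-r vector per trainable token, while lora_B becomes a shared nn.Linear(r, embedding_dim, bias=False). A quick shape sketch; r=4 and dim=768 are assumed values, not taken from this repository:

# Illustrative shapes only; r and dim are assumptions.
import torch
import torch.nn as nn

r, dim = 4, 768

# Old layout: fixed matrices, lora_A re-allocated whenever tokens were added.
old_lora_A = nn.Parameter(torch.zeros(r, 0))     # (r, n_trainable), starts empty
old_lora_B = nn.Parameter(torch.zeros(dim, r))   # (dim, r)

# New layout: one rank-r vector per trainable token plus a shared up-projection.
new_lora_A = nn.ParameterList()                  # grows with .append(...)
new_lora_B = nn.Linear(r, dim, bias=False)       # weight shape (dim, r)

new_lora_A.append(nn.Parameter(torch.zeros(r)))  # register one trainable token
print(new_lora_B(new_lora_A[0]).shape)           # torch.Size([768]): per-token delta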
@@ -83,49 +83,64 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n1 = self.lora_A.shape[1]
+        n1 = len(self.lora_A)
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.lora_A.append(self.weight.new_zeros(self.r))
+
+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
+    def get_weights(self, input_ids: torch.Tensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        trainable_ids = trainable_ids[mask]
+
+        elems = [self.lora_A[id] for id in trainable_ids]
+
+        if len(elems) == 0:
+            return None, mask
+
+        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
-        self.lora_A = lora_A
+        return weights, mask
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
+            self.lora_A = nn.ParameterList()
+            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
         if self.merge_weights and self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data -= weights
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
         if self.merge_weights and not self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data += weights
             self.merged = True
 
     def forward(self, input_ids: torch.Tensor):
         result = nn.Embedding.forward(self, input_ids)
 
         if self.r > 0 and not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            trainable_ids = trainable_ids[mask]
-
-            after_A = F.embedding(
-                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+            weights, mask = self.get_weights(input_ids)
+            if weights is not None:
+                result[mask] += weights
 
         return result
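Taken together, the lora.py changes replace the matrix-resize approach with per-token LoRA rows looked up through a shared get_weights() helper, which train(), eval(), forward(), and the new persist() all reuse. Below is a self-contained sketch of that mechanism; the class, its constructor arguments, and the method name mark_trainable are assumptions that mirror the diff rather than the repository's LoraEmbedding, and dropout plus the merged-flag bookkeeping are omitted:

# Illustration of the per-token LoRA mechanism; names mirror the diff, but the
# class itself, its constructor arguments, and mark_trainable are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerTokenLoraSketch(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 4, alpha: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim), requires_grad=False)
        self.r = r
        self.scaling = alpha / r
        # -1 means "this token has no LoRA slot"; otherwise an index into lora_A.
        self.register_buffer("trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long))
        self.lora_A = nn.ParameterList()                       # one r-vector per trainable token
        self.lora_B = nn.Linear(r, embedding_dim, bias=False)  # shared up-projection
        nn.init.zeros_(self.lora_B.weight)                     # delta starts at zero

    def mark_trainable(self, new_ids: torch.Tensor):
        n1 = len(self.lora_A)
        self.trainable_ids[new_ids] = torch.arange(n1, n1 + new_ids.shape[0])
        for _ in new_ids:
            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))

    def get_weights(self, input_ids: torch.Tensor):
        slots = self.trainable_ids[input_ids]
        mask = slots != -1
        elems = [self.lora_A[int(i)] for i in slots[mask]]
        if not elems:
            return None, mask
        return self.lora_B(torch.stack(elems)) * self.scaling, mask

    def persist(self):
        # Fold the learned deltas into the base weight, then clear all LoRA state.
        weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
        if weights is not None:
            with torch.no_grad():
                self.weight[mask] += weights
        self.trainable_ids[:] = -1
        self.lora_A = nn.ParameterList()

    def forward(self, input_ids: torch.Tensor):
        result = F.embedding(input_ids, self.weight)
        weights, mask = self.get_weights(input_ids)
        if weights is not None:
            result[mask] += weights
        return result


emb = PerTokenLoraSketch(100, 16)
emb.mark_trainable(torch.tensor([3, 7]))
out = emb(torch.tensor([1, 3, 7, 9]))  # only positions 1 and 2 receive a LoRA delta
emb.persist()                          # bake deltas into emb.weight, reset LoRA state
print(out.shape)                       # torch.Size([4, 16])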