Diffstat (limited to 'models')
-rw-r--r--   models/clip/embeddings.py    2
-rw-r--r--   models/lora.py              42
2 files changed, 21 insertions, 23 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4444cf9..d02ccc3 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -64,7 +64,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)

         self.token_embedding.mark_trainable(token_ids)
-        self.token_embedding.weight.data[token_ids] = initializer
+        self.token_embedding.weight[token_ids].data = initializer

     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
diff --git a/models/lora.py b/models/lora.py
index b7fa58f..a8197a5 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -42,7 +42,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )

-        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
         self.trainable_ids -= 1

         if r > 0:
@@ -76,7 +76,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):

         return new_emb

-    def mark_trainable(self, input_ids):
+    def mark_trainable(self, input_ids: torch.LongTensor):
         trainable_ids = self.trainable_ids[input_ids]
         new_ids = input_ids[trainable_ids == -1]

@@ -87,15 +87,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(self.weight.new_zeros(self.r))
-
-    def persist(self):
-        if self.r > 0:
-            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            if weights is not None:
-                self.weight[mask].data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+            self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True))

     def get_weights(self, input_ids: torch.Tensor):
         trainable_ids = self.trainable_ids[input_ids]
@@ -104,16 +96,25 @@ class LoraEmbedding(nn.Embedding, LoraLayer):

         elems = [self.lora_A[id] for id in trainable_ids]

-        if len(elems) == 0:
-            return None, mask
-
-        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+        if len(elems) != 0:
+            weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+        else:
+            weights = self.weight.new_zeros(self.embedding_dim)

         return weights, mask

+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
+            self.trainable_ids[:] = -1
             self.lora_A = nn.ParameterList()
         nn.init.zeros_(self.lora_B.weight)

@@ -122,8 +123,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if self.merge_weights and self.merged:
             if self.r > 0:
                 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                if weights is not None:
-                    self.weight[mask].data -= weights
+                self.weight[mask].data -= weights
             self.merged = False

     def eval(self):
@@ -131,16 +131,14 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if self.merge_weights and not self.merged:
             if self.r > 0:
                 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                if weights is not None:
-                    self.weight[mask].data += weights
+                self.weight[mask].data += weights
             self.merged = True

-    def forward(self, input_ids: torch.Tensor):
+    def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)

         if self.r > 0 and not self.merged:
             weights, mask = self.get_weights(input_ids)
-            if weights is not None:
-                result[mask] += weights
+            result[mask] += weights

         return result
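Usage note (not part of the commit): a minimal sketch of the get_weights contract after this change. get_weights now always returns a weight tensor, falling back to a zero vector when none of the looked-up ids are trainable, which is why the "if weights is not None" guards in train, eval and forward were dropped. The constructor arguments and token ids below are assumptions for illustration only.

import torch

from models.lora import LoraEmbedding

# Hypothetical construction; the exact constructor signature is an assumption.
emb = LoraEmbedding(num_embeddings=49410, embedding_dim=768, r=8, lora_alpha=8)

# Mark two (assumed) placeholder token ids as trainable.
emb.mark_trainable(torch.tensor([49408, 49409]))

# Some looked-up ids are trainable: one LoRA delta row per masked position.
weights, mask = emb.get_weights(torch.tensor([320, 49408, 49409]))
print(weights.shape)  # torch.Size([2, 768])

# No looked-up id is trainable: a zero vector instead of None,
# so the caller can apply it unconditionally.
input_ids = torch.tensor([320, 1125])
weights, mask = emb.get_weights(input_ids)
print(weights.shape)  # torch.Size([768])

result = torch.nn.Embedding.forward(emb, input_ids)
result[mask] += weights  # no-op here, but no None check needed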