Diffstat (limited to 'models')

 models/clip/embeddings.py |  41 ++-------
 models/lora.py            |  77 +++++++-------
 models/sparse.py          | 110 ++++++++++++++++++++
 3 files changed, 157 insertions(+), 71 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
    )
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
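
For context, a minimal usage sketch of the new entry point. This is illustrative and not part of this commit; the checkpoint name is a placeholder.

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Hypothetical wiring: swap the text encoder's embeddings for the managed
    # version. Token lookups then add per-token sparse deltas on top of the
    # frozen token embedding matrix.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.1)
    assert embeddings is text_encoder.text_model.embeddings
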
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
             self.trainable_ids[:] = -1
             self.lora_A = nn.ParameterList()
             nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-                self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
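
A sketch of the merge path kept in LoraEmbedding. This is illustrative only; `emb` is assumed to be a LoraEmbedding constructed with merge_weights=True.

    import torch

    ids = torch.tensor([3, 5])
    emb.mark_trainable(ids)

    # Leaving train mode folds the current per-id deltas into the frozen
    # weight matrix and sets `merged`; get_weights() then short-circuits to
    # zeros, so forward() reduces to a plain embedding lookup.
    emb.train(False)
    out = emb(ids)
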
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..bd45696
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,110 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class SparseEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        alpha: int = 1,
+        dropout: float = 0.0,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
+
+        self.trainable = nn.ParameterList()
+        self.scaling = alpha
+        self.dropout_p = dropout
+        self.weight.requires_grad = False
+
+        if dropout > 0.:
+            self.dropout = nn.Dropout(p=dropout)
+        else:
+            self.dropout = nn.Identity()
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = SparseEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.scaling,
+            self.dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        for param in self.trainable:
+            new_emb.trainable.append(param)
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.LongTensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1]
+
+        if new_ids.shape[0] == 0:
+            return
+
+        n1 = len(self.trainable)
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.trainable.append(self.weight.new_zeros(self.embedding_dim))
+
+    def get_weights(self, input_ids: torch.Tensor):
+        original_shape = input_ids.shape
+
+        if len(input_ids.shape) != 1:
+            input_ids = input_ids.view(input_ids.shape[0] * input_ids.shape[1])
+
+        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
+
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        elems = [self.trainable[id] for id in trainable_ids[mask]]
+
+        if len(elems) != 0:
+            w = self.dropout(torch.stack(elems)) * self.scaling
+            weights[mask] = w.to(dtype=weights.dtype)
+
+        if len(original_shape) != 1:
+            weights = weights.view(original_shape[0], original_shape[1], -1)
+
+        return weights
+
+    def persist(self):
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.trainable = nn.ParameterList()
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, "trainable"):
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        self.trainable.train(mode)
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        self.trainable.eval()
+
+    def forward(self, input_ids: torch.LongTensor):
+        result = nn.Embedding.forward(self, input_ids)
+        result += self.get_weights(input_ids)
+        return result
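
A minimal lifecycle sketch for the new SparseEmbedding. This is illustrative; the vocabulary size, dimension, and token ids are placeholders.

    import torch

    from models.sparse import SparseEmbedding

    emb = SparseEmbedding(49408, 768, alpha=1, dropout=0.0)

    # Only explicitly marked ids receive trainable per-token delta vectors;
    # every other row of emb.weight stays frozen.
    emb.mark_trainable(torch.tensor([49406, 49407]))

    # forward() = frozen lookup + scaled (optionally dropped-out) delta.
    out = emb(torch.tensor([[1, 49406, 49407]]))

    # After training, fold the learned deltas back into the weight matrix
    # and clear the trainable set.
    emb.persist()
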
