 models/lora.py                  | 145 -
 train_lora.py                   |  14 +-
 training/attention_processor.py |  47 +
 3 files changed, 50 insertions(+), 156 deletions(-)
diff --git a/models/lora.py b/models/lora.py
deleted file mode 100644
index e506cff..0000000
--- a/models/lora.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from typing import Optional
-import math
-
-import torch
-import torch.nn as nn
-
-
-class LoraLayer():
-    def __init__(
-        self,
-        r: int,
-        lora_alpha: int,
-        lora_dropout: float,
-        merge_weights: bool,
-    ):
-        self.r = r
-        self.lora_alpha = lora_alpha
-        self.lora_dropout_p = lora_dropout
-
-        if lora_dropout > 0.:
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = nn.Identity()
-
-        self.merged = False
-        self.merge_weights = merge_weights
-
-
-class LoraEmbedding(nn.Embedding, LoraLayer):
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        r: int = 0,
-        lora_alpha: int = 1,
-        lora_dropout: float = 0.0,
-        merge_weights: bool = True,
-        **kwargs
-    ):
-        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
-        LoraLayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
-
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
-
-        self.lora_A = nn.ParameterList()
-        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-        self.scaling = self.lora_alpha / self.r
-        self.weight.requires_grad = False
-
-        self.reset_parameters()
-
-    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
-        n = min(self.num_embeddings, new_num_embeddings)
-
-        new_emb = LoraEmbedding(
-            new_num_embeddings,
-            self.embedding_dim,
-            self.r,
-            self.lora_alpha,
-            self.lora_dropout_p,
-            device=self.weight.device,
-            dtype=self.weight.dtype
-        )
-        if initializer_factor is not None:
-            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-        else:
-            nn.init.zeros_(new_emb.weight.data)
-        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        for param in self.lora_A:
-            new_emb.lora_A.append(param)
-        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
-        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
-
-        return new_emb
-
-    def mark_trainable(self, input_ids: torch.LongTensor):
-        trainable_ids = self.trainable_ids[input_ids]
-        new_ids = input_ids[trainable_ids == -1]
-
-        if new_ids.shape[0] == 0:
-            return
-
-        n1 = len(self.lora_A)
-        n2 = n1 + new_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n1, n2)
-        for _ in new_ids:
-            w = self.weight.new_zeros(self.r)
-            self.lora_A.append(w)
-
-        if len(self.lora_A) > 1:
-            elems = torch.stack([param for param in self.lora_A])
-            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
-
-    def get_weights(self, input_ids: torch.Tensor):
-        if len(input_ids.shape) != 1:
-            return torch.stack([self.get_weights(batch) for batch in input_ids])
-
-        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-
-        if not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            elems = [self.lora_A[id] for id in trainable_ids[mask]]
-
-            if len(elems) != 0:
-                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-                weights[mask] = w.to(dtype=weights.dtype)
-
-        return weights
-
-    def persist(self):
-        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.lora_A = nn.ParameterList()
-        nn.init.zeros_(self.lora_B.weight)
-
-    def reset_parameters(self):
-        nn.Embedding.reset_parameters(self)
-        if hasattr(self, "lora_A"):
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
-            nn.init.zeros_(self.lora_B.weight)
-
-    def train(self, mode: bool = True):
-        nn.Embedding.train(self, mode)
-        self.lora_A.train(mode)
-        self.lora_B.train(mode)
-        if not mode and self.merge_weights and not self.merged:
-            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = True
-        elif self.merge_weights and self.merged:
-            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = False
-
-    def eval(self):
-        nn.Embedding.eval(self)
-        self.lora_A.eval()
-        self.lora_B.eval()
-
-    def forward(self, input_ids: torch.LongTensor):
-        result = nn.Embedding.forward(self, input_ids)
-        result += self.get_weights(input_ids)
-        return result
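
For reference, the deleted `LoraEmbedding` combined a frozen `nn.Embedding` with per-token low-rank updates: token ids registered via `mark_trainable` each get a rank-`r` vector in `lora_A`, which is projected up through the shared `lora_B` and added to the frozen row, scaled by `lora_alpha / r`. A minimal usage sketch, inferred from the class above (the vocabulary size, rank, and token ids are illustrative, not values from this repository):

```python
import torch

from models.lora import LoraEmbedding  # path as it existed before this commit

# Illustrative sizes; r must be > 0 because the constructor divides lora_alpha by r.
emb = LoraEmbedding(num_embeddings=49410, embedding_dim=768, r=4, lora_alpha=4)

# Only the listed token ids receive trainable low-rank deltas; every other row stays frozen.
emb.mark_trainable(torch.tensor([49408, 49409]))

# Forward pass: frozen embedding lookup plus the scaled lora_B(lora_A) delta for marked ids.
input_ids = torch.tensor([[49408, 11, 22, 49409]])
out = emb(input_ids)  # shape (1, 4, 768)

# persist() folds the accumulated deltas back into emb.weight and clears the LoRA state.
emb.persist()
```
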
diff --git a/train_lora.py b/train_lora.py
index 737af58..dea58cf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ import hidet
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from peft import LoraConfig, LoraModel
+from peft import LoraConfig, get_peft_model
 # from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
@@ -731,7 +731,7 @@ def main():
         lora_dropout=args.lora_dropout,
         bias=args.lora_bias,
     )
-    unet = LoraModel(unet_config, unet)
+    unet = get_peft_model(unet, unet_config)
 
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
@@ -740,7 +740,7 @@ def main():
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
+    text_encoder = get_peft_model(text_encoder, text_encoder_config)
 
     vae.enable_slicing()
 
@@ -1167,14 +1167,6 @@ def main():
             group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(placeholder_tokens) != 0:
-            #     params_to_optimize.append({
-            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-            #         "lr": learning_rate_emb,
-            #         "weight_decay": 0,
-            #     })
-            #     group_labels.append("emb")
-
             params_to_optimize.append({
                 "params": (
                     param
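
With the in-tree `models/lora.py` gone, LoRA injection now goes through peft's public API: a `LoraConfig` describes the adapter and `get_peft_model` wraps the base module, replacing the direct `LoraModel(config, model)` construction used above. A rough sketch of the pattern (the checkpoint id and config values are placeholders; `train_lora.py` takes its values from CLI arguments, and the `target_modules` shown here are only an assumption for a Stable Diffusion UNet):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig, get_peft_model

# Example checkpoint only; the training script loads its own model.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

unet_config = LoraConfig(
    r=8,               # rank of the low-rank update
    lora_alpha=8,      # scaling numerator (effective scale is alpha / r)
    lora_dropout=0.0,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # assumed attention projections
)

unet = get_peft_model(unet, unet_config)  # freezes base weights, injects LoRA layers
unet.print_trainable_parameters()
```
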
diff --git a/training/attention_processor.py b/training/attention_processor.py
new file mode 100644
index 0000000..4309bd4
--- /dev/null
+++ b/training/attention_processor.py
@@ -0,0 +1,47 @@
+from typing import Callable, Optional, Union
+
+import xformers
+import xformers.ops
+
+from diffusers.models.attention_processor import Attention
+
+
+class XFormersAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        query = query.to(key.dtype)
+        value = value.to(key.dtype)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
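
The new processor follows the memory-efficient attention path: queries, keys, and values come from the `Attention` module's projections, are reshaped head-first, run through `xformers.ops.memory_efficient_attention`, and projected back out. A hedged sketch of how such a processor is typically installed on a diffusers UNet (the checkpoint id is only an example; the training script does its own wiring):

```python
from diffusers import UNet2DConditionModel

from training.attention_processor import XFormersAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# A single processor instance is shared by every Attention module in the UNet;
# pass a dict keyed by module name instead to target specific blocks.
unet.set_attn_processor(XFormersAttnProcessor())
```
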
