-rw-r--r--  models/lora.py                   | 145
-rw-r--r--  train_lora.py                    |  14
-rw-r--r--  training/attention_processor.py  |  47
3 files changed, 50 insertions(+), 156 deletions(-)
diff --git a/models/lora.py b/models/lora.py
deleted file mode 100644
index e506cff..0000000
--- a/models/lora.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from typing import Optional
-import math
-
-import torch
-import torch.nn as nn
-
-
-class LoraLayer():
-    def __init__(
-        self,
-        r: int,
-        lora_alpha: int,
-        lora_dropout: float,
-        merge_weights: bool,
-    ):
-        self.r = r
-        self.lora_alpha = lora_alpha
-        self.lora_dropout_p = lora_dropout
-
-        if lora_dropout > 0.:
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = nn.Identity()
-
-        self.merged = False
-        self.merge_weights = merge_weights
-
-
-class LoraEmbedding(nn.Embedding, LoraLayer):
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        r: int = 0,
-        lora_alpha: int = 1,
-        lora_dropout: float = 0.0,
-        merge_weights: bool = True,
-        **kwargs
-    ):
-        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
-        LoraLayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
-
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
-
-        self.lora_A = nn.ParameterList()
-        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-        self.scaling = self.lora_alpha / self.r
-        self.weight.requires_grad = False
-
-        self.reset_parameters()
-
-    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
-        n = min(self.num_embeddings, new_num_embeddings)
-
-        new_emb = LoraEmbedding(
-            new_num_embeddings,
-            self.embedding_dim,
-            self.r,
-            self.lora_alpha,
-            self.lora_dropout_p,
-            device=self.weight.device,
-            dtype=self.weight.dtype
-        )
-        if initializer_factor is not None:
-            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-        else:
-            nn.init.zeros_(new_emb.weight.data)
-        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        for param in self.lora_A:
-            new_emb.lora_A.append(param)
-        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
-        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
-
-        return new_emb
-
-    def mark_trainable(self, input_ids: torch.LongTensor):
-        trainable_ids = self.trainable_ids[input_ids]
-        new_ids = input_ids[trainable_ids == -1]
-
-        if new_ids.shape[0] == 0:
-            return
-
-        n1 = len(self.lora_A)
-        n2 = n1 + new_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n1, n2)
-        for _ in new_ids:
-            w = self.weight.new_zeros(self.r)
-            self.lora_A.append(w)
-
-        if len(self.lora_A) > 1:
-            elems = torch.stack([param for param in self.lora_A])
-            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
-
-    def get_weights(self, input_ids: torch.Tensor):
-        if len(input_ids.shape) != 1:
-            return torch.stack([self.get_weights(batch) for batch in input_ids])
-
-        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-
-        if not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            elems = [self.lora_A[id] for id in trainable_ids[mask]]
-
-            if len(elems) != 0:
-                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-                weights[mask] = w.to(dtype=weights.dtype)
-
-        return weights
-
-    def persist(self):
-        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.lora_A = nn.ParameterList()
-        nn.init.zeros_(self.lora_B.weight)
-
-    def reset_parameters(self):
-        nn.Embedding.reset_parameters(self)
-        if hasattr(self, "lora_A"):
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
-            nn.init.zeros_(self.lora_B.weight)
-
-    def train(self, mode: bool = True):
-        nn.Embedding.train(self, mode)
-        self.lora_A.train(mode)
-        self.lora_B.train(mode)
-        if not mode and self.merge_weights and not self.merged:
-            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = True
-        elif self.merge_weights and self.merged:
-            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = False
-
-    def eval(self):
-        nn.Embedding.eval(self)
-        self.lora_A.eval()
-        self.lora_B.eval()
-
-    def forward(self, input_ids: torch.LongTensor):
-        result = nn.Embedding.forward(self, input_ids)
-        result += self.get_weights(input_ids)
-        return result
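The deleted models/lora.py implemented the standard LoRA reparameterization for an embedding table: the base weight stays frozen, selected rows receive a low-rank update B(A) scaled by lora_alpha / r, and train(False)/persist() fold that update back into the base weight. A minimal, self-contained sketch of the same idea (illustrative only; it drops the per-token trainable_ids bookkeeping and is not the deleted class's API):

    import torch
    import torch.nn as nn

    class TinyLoraEmbedding(nn.Module):
        """Frozen embedding plus a rank-r update, as in the removed LoraEmbedding."""
        def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 4, lora_alpha: int = 1):
            super().__init__()
            self.base = nn.Embedding(num_embeddings, embedding_dim)
            self.base.weight.requires_grad = False         # base table stays frozen
            self.lora_A = nn.Embedding(num_embeddings, r)  # per-token rank-r factors
            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
            self.scaling = lora_alpha / r                  # same scaling rule as above
            nn.init.zeros_(self.lora_B.weight)             # update starts as a no-op

        def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
            return self.base(input_ids) + self.lora_B(self.lora_A(input_ids)) * self.scaling

        def merge(self):
            # Fold the low-rank update into the frozen table, mirroring persist().
            self.base.weight.data += self.lora_B(self.lora_A.weight) * self.scaling
            nn.init.zeros_(self.lora_B.weight)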
diff --git a/train_lora.py b/train_lora.py
index 737af58..dea58cf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ import hidet
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from peft import LoraConfig, LoraModel
+from peft import LoraConfig, get_peft_model
 # from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
@@ -731,7 +731,7 @@ def main():
         lora_dropout=args.lora_dropout,
         bias=args.lora_bias,
     )
-    unet = LoraModel(unet_config, unet)
+    unet = get_peft_model(unet, unet_config)
 
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
@@ -740,7 +740,7 @@ def main():
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
+    text_encoder = get_peft_model(text_encoder, text_encoder_config)
 
     vae.enable_slicing()
 
@@ -1167,14 +1167,6 @@ def main():
             group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(placeholder_tokens) != 0:
-            #     params_to_optimize.append({
-            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-            #         "lr": learning_rate_emb,
-            #         "weight_decay": 0,
-            #     })
-            #     group_labels.append("emb")
-
             params_to_optimize.append({
                 "params": (
                     param
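Replacing the removed LoraModel(config, model) constructor with get_peft_model(model, config) matches peft's current public API: the wrapper freezes the base weights and injects trainable LoRA matrices into the targeted sub-modules. A hedged, standalone sketch of the call shape (the toy module, rank, and target names below are illustrative stand-ins, not the trainer's actual arguments):

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    # Toy module standing in for the UNet; only the sub-module names matter for matching.
    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.to_q = nn.Linear(64, 64)
            self.to_v = nn.Linear(64, 64)

        def forward(self, x):
            return self.to_q(x) + self.to_v(x)

    config = LoraConfig(
        r=8,                              # cf. args.lora_r
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.0,                 # cf. args.lora_dropout
        bias="none",                      # cf. args.lora_bias
    )
    model = get_peft_model(ToyBlock(), config)
    model.print_trainable_parameters()    # only the injected LoRA matrices require grad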
diff --git a/training/attention_processor.py b/training/attention_processor.py
new file mode 100644
index 0000000..4309bd4
--- /dev/null
+++ b/training/attention_processor.py
@@ -0,0 +1,47 @@
+from typing import Callable, Optional, Union
+
+import xformers
+import xformers.ops
+
+from diffusers.models.attention_processor import Attention
+
+
+class XFormersAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        query = query.to(key.dtype)
+        value = value.to(key.dtype)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
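The new processor routes attention through xformers.ops.memory_efficient_attention, casting query and value to the key dtype before the call. It can be installed on a diffusers UNet via set_attn_processor; a short usage sketch, assuming a UNet2DConditionModel (the checkpoint id below is a placeholder, not taken from this repository):

    from diffusers import UNet2DConditionModel

    from training.attention_processor import XFormersAttnProcessor

    # Load a UNet and install the custom processor on every attention block.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
    )
    unet.set_attn_processor(XFormersAttnProcessor())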