Diffstat (limited to 'training/lora.py')
-rw-r--r-- | training/lora.py | 107
1 file changed, 0 insertions, 107 deletions
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
import torch
import torch.nn as nn

from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}"
            )

        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.lora_down.weight, std=1 / rank)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, hidden_states):
        down_hidden_states = self.lora_down(hidden_states)
        up_hidden_states = self.lora_up(down_hidden_states)

        return up_hidden_states


class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class LoRAXFormersCrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
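
For reference, a minimal sketch of how LoRA processors like the ones in this deleted file are typically attached to a diffusers UNet2DConditionModel, following the usual pattern of the diffusers LoRA training scripts of that era. The model id, rank, and learning rate are illustrative assumptions, not part of the deleted file.

# Illustrative wiring sketch; model id, rank, and lr are placeholder assumptions.
# LoRACrossAttnProcessor is the class defined in the file shown above.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.requires_grad_(False)  # freeze the base weights; only the LoRA layers will train

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no text conditioning); attn2 cross-attends to the text embeddings
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)

# Newly created LoRA layers keep requires_grad=True, so only they are optimized.
lora_params = [p for proc in lora_attn_procs.values() for p in proc.parameters()]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

Since every LoRA branch is added as `base_projection(x) + scale * lora(x)`, the `scale` argument of the processors' `__call__` lets the LoRA contribution be blended in at inference time (0.0 disables it, 1.0 applies it fully).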