1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available
# xformers is an optional accelerator. Default the module-level name to None
# and only rebind it to the real package (with its ``ops`` submodule loaded)
# when diffusers reports that xformers can be imported.
xformers = None
if is_xformers_available():
    import xformers
    import xformers.ops
class LoraAttnProcessor(ModelMixin, ConfigMixin):
    """Attention processor that adds low-rank (LoRA) residuals to the key/value projections.

    The query projection and the output layers of the wrapped ``CrossAttention``
    module are used unchanged. ``attn.to_k`` and ``attn.to_v`` each receive a
    trainable rank-``r`` residual ``up(down(x)) * self.scale``.

    Args:
        cross_attention_dim: Last-dimension size of the tensor fed to the key/value
            projections (``encoder_hidden_states``, or ``hidden_states`` when no
            encoder states are given).
        inner_dim: Output size of the LoRA up-projections. It must match the output
            size of ``attn.to_k`` / ``attn.to_v``, because the LoRA delta is added
            to them.
        r: LoRA rank, required to satisfy ``1 <= r <= min(cross_attention_dim, inner_dim)``.

    Raises:
        ValueError: If ``r`` is outside the valid range.
    """

    @register_to_config
    def __init__(
        self,
        cross_attention_dim,
        inner_dim,
        r: int = 4
    ):
        super().__init__()
        max_rank = min(cross_attention_dim, inner_dim)
        if r > max_rank:
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {max_rank}"
            )
        if r < 1:
            # r == 0 would otherwise surface as a ZeroDivisionError in the init
            # below, and negative ranks as an opaque error inside nn.Linear.
            raise ValueError(f"LoRA rank {r} must be a positive integer")

        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
        # Multiplier on the LoRA delta; 0.0 disables the adapter entirely.
        self.scale = 1.0

        # Down-projections start random and up-projections start at zero, so the
        # LoRA delta is exactly zero at init and the processor initially
        # reproduces the base key/value projections.
        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_k_up.weight)
        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_v_up.weight)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run ``attn`` with LoRA-adjusted key and value projections.

        Args:
            attn: Attention module whose projections, head-reshaping helpers and
                ``to_out`` layers are reused.
            hidden_states: Query-side input, unpacked below as a rank-3 tensor
                whose second dimension is the sequence length.
            encoder_hidden_states: Key/value-side input. When ``None``,
                ``hidden_states`` is used, i.e. self-attention.
            attention_mask: Optional mask, normalized via ``attn.prepare_attention_mask``.

        Returns:
            The attention output after ``attn.to_out[0]`` (linear projection) and
            ``attn.to_out[1]`` (dropout).
        """
        # Only the sequence length is needed (for mask preparation); the
        # unpacking also rejects inputs that are not rank 3.
        _, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key_delta = self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
        value_delta = self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
        key = attn.to_k(encoder_hidden_states) + key_delta
        value = attn.to_v(encoder_hidden_states) + value_delta

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if xformers is not None:
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        else:
            # The module-level `xformers` is None when the package is missing.
            # Use the plain softmax(Q K^T * scale + mask) V path instead of
            # crashing on `None.ops`.
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = attention_probs.bmm(value)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
|