path: root/training/lora.py
import torch.nn as nn

from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class LoraAttnProcessor(ModelMixin, ConfigMixin):
    """Attention processor that adds trainable low-rank (LoRA) updates to the key and
    value projections of a diffusers CrossAttention block, leaving the original
    attention weights untouched."""

    @register_to_config
    def __init__(
        self,
        cross_attention_dim,
        inner_dim,
        r: int = 4
    ):
        super().__init__()

        if r > min(cross_attention_dim, inner_dim):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(cross_attention_dim, inner_dim)}"
            )

        # Low-rank adapters for the key and value projections: each is a pair of linear
        # maps (down to rank r, back up to inner_dim) whose product is the LoRA update
        # added on top of the frozen attention weights.
        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)

        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)

        self.scale = 1.0

        # Initialize the down projections with a small Gaussian and the up projections
        # with zeros, so the LoRA update is exactly zero at the start of training.
        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_k_up.weight)

        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_v_up.weight)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if xformers is None:
            raise ImportError("LoraAttnProcessor requires xformers for memory-efficient attention")

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        # Self-attention if no encoder states are given, cross-attention otherwise.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # Add the scaled low-rank LoRA update on top of the frozen key/value projections.
        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale

        # Reshape to (batch * heads, seq_len, head_dim) as expected by xformers.
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
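

# Usage sketch (illustrative): attach the processor to a diffusers CrossAttention block
# via its `set_processor` method. The dimensions below are assumptions for a Stable
# Diffusion-style UNet cross-attention block, not values taken from this repository.
#
#     lora_proc = LoraAttnProcessor(cross_attention_dim=768, inner_dim=320, r=4)
#     attn_block.set_processor(lora_proc)  # `attn_block` is a CrossAttention instance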