path: root/training/lora.py
blob: d8dc147a7d19c114ec470c507b034f602d8e8e6d
import torch.nn as nn

from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
# XFormersCrossAttnProcessor ships with older diffusers releases; newer releases
# renamed it XFormersAttnProcessor under diffusers.models.attention_processor.
from diffusers.models.cross_attention import XFormersCrossAttnProcessor


class LoraAttention(ModelMixin, ConfigMixin):
    """xFormers cross-attention processor with an additive LoRA (low-rank) residual."""

    @register_to_config
    def __init__(self, in_features, out_features, r=4):
        super().__init__()

        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
            )

        # Low-rank decomposition: project features down to rank r, then back up.
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = 1.0

        self.processor = XFormersCrossAttnProcessor()

        # Small random init for the down projection, scaled by the rank; the up
        # projection starts at zero so the LoRA branch is a no-op before training.
        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_up.weight)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
        # Run the stock xFormers attention, then add the scaled low-rank update.
        # `number` is kept in the signature for call-site compatibility but is not
        # forwarded: the stock XFormersCrossAttnProcessor does not accept it.
        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)
        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
        return hidden_states
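
# Usage sketch (not part of the original file; the model id, rank, and learning
# rate below are illustrative assumptions). One LoraAttention per attention block
# can be registered on a diffusers UNet via set_attn_processor, with the hidden
# size looked up per block the way diffusers' LoRA training examples do, and only
# the LoRA weights handed to the optimizer while the UNet stays frozen:
#
#   import itertools
#   import torch
#   from diffusers import UNet2DConditionModel
#
#   unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
#   unet.requires_grad_(False)
#
#   lora_layers = {}
#   for name in unet.attn_processors:
#       if name.startswith("mid_block"):
#           hidden_size = unet.config.block_out_channels[-1]
#       elif name.startswith("up_blocks"):
#           block_id = int(name[len("up_blocks.")])
#           hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
#       else:  # down_blocks
#           block_id = int(name[len("down_blocks.")])
#           hidden_size = unet.config.block_out_channels[block_id]
#       lora_layers[name] = LoraAttention(in_features=hidden_size, out_features=hidden_size, r=4)
#
#   unet.set_attn_processor(lora_layers)
#   params = itertools.chain(*(l.parameters() for l in lora_layers.values()))
#   optimizer = torch.optim.AdamW(params, lr=1e-4)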