import torch.nn as nn

from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.cross_attention import XFormersCrossAttnProcessor


class LoraAttention(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_features, out_features, r=4):
        super().__init__()

        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            )

        # Low-rank decomposition: project down to rank r, then back up.
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = 1.0

        # Delegate the actual attention computation to the xformers processor.
        self.processor = XFormersCrossAttnProcessor()

        # Initialize so the LoRA branch starts as a no-op (lora_up is zero).
        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_up.weight)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
        # XFormersCrossAttnProcessor does not accept the extra `number` argument,
        # so it is not forwarded to the processor call.
        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        # Add the scaled low-rank update to the attention output.
        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale

        return hidden_states
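

# A minimal usage sketch (not part of the original snippet): attach a LoraAttention
# processor to every attention block of a Stable Diffusion UNet. It assumes a diffusers
# version that still ships XFormersCrossAttnProcessor, that xformers is installed, and
# that the model id and rank below are only illustrative. Because the LoRA update is
# applied to the attention output, in_features == out_features == the block's hidden size.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Processor names look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor";
    # recover the hidden size of the owning block from the UNet config.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoraAttention(hidden_size, hidden_size, r=4)

unet.set_attn_processor(lora_attn_procs)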