import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class LoraAttnProcessor(ModelMixin, ConfigMixin):
    """Attention processor that adds low-rank (LoRA) terms to key and value.

    For each of the key and value projections, a pair of bias-free linear
    layers (``down``: ``cross_attention_dim -> r``, ``up``: ``r -> inner_dim``)
    is applied to the encoder hidden states, and the result, multiplied by
    ``self.scale``, is added to the output of the wrapped attention module's
    own ``to_k`` / ``to_v`` projection. The query projection is left untouched.

    The ``up`` weights are zero-initialised, so a freshly constructed processor
    adds nothing to the key/value projections until it has been trained or
    loaded from a checkpoint.

    Attention itself is computed with
    ``xformers.ops.memory_efficient_attention``; xformers must therefore be
    available when the processor is called (it is not needed to construct it).
    """

    @register_to_config
    def __init__(
        self,
        cross_attention_dim,
        inner_dim,
        r: int = 4
    ):
        """Create the LoRA layers.

        Args:
            cross_attention_dim: Input feature size of the ``down`` layers,
                i.e. the last dimension of the tensor fed to ``to_k``/``to_v``.
            inner_dim: Output feature size of the ``up`` layers; must match the
                output size of the wrapped module's ``to_k``/``to_v`` so the
                two results can be summed.
            r: Rank of the low-rank decomposition.

        Raises:
            ValueError: If ``r`` is smaller than 1 or larger than
                ``min(cross_attention_dim, inner_dim)``.
        """
        super().__init__()

        # Rank 0 would divide by zero in the init below and a negative rank is
        # not a valid layer size, so reject both up front with a clear message.
        if r < 1:
            raise ValueError(f"LoRA rank must be a positive integer, got {r}")

        if r > min(cross_attention_dim, inner_dim):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
            )

        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
        # Multiplier applied to both LoRA terms before they are added.
        self.scale = 1.0

        # Random "down" + zero "up": the product starts at exactly zero.
        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_k_up.weight)
        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_v_up.weight)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run attention for ``attn`` with the LoRA terms added to key/value.

        Args:
            attn: The attention module whose projections (``to_q``, ``to_k``,
                ``to_v``, ``to_out``) and head reshaping helpers are used.
            hidden_states: 3-D input tensor; its second dimension is taken as
                the sequence length when preparing the attention mask.
            encoder_hidden_states: Source of keys and values. When ``None``,
                ``hidden_states`` is used instead.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask`` and then to xformers as
                ``attn_bias``.

        Returns:
            The result of ``attn.to_out[1](attn.to_out[0](...))`` applied to
            the attention output.

        Raises:
            ModuleNotFoundError: If xformers is not available.
        """
        # Without this guard the call below fails with an opaque
        # "'NoneType' object has no attribute 'ops'".
        if xformers is None:
            raise ModuleNotFoundError(
                "LoraAttnProcessor requires xformers for memory-efficient attention, "
                "but xformers is not available in this environment."
            )

        # Three-way unpack also enforces that hidden_states is rank 3.
        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states