From 47b2fba05ba0f2f5335321d94e0634a7291980c5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Dec 2022 16:48:25 +0100
Subject: Some LoRA fixes (still broken)

---
 training/lora.py | 68 +++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 55 insertions(+), 13 deletions(-)

diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
 
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
--
cgit v1.2.3-70-g09d2
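
For reference, the low-rank update that the new processor adds to the key and value projections can be exercised on its own. The following is a minimal sketch in plain PyTorch and is not part of the patch: to_k stands in for the frozen attention projection, and the sizes 768/320 and sequence length 77 are illustrative values, not taken from the commit.

import torch
import torch.nn as nn

cross_attention_dim, inner_dim, r = 768, 320, 4  # illustrative sizes, not from the patch

# Frozen base projection, standing in for CrossAttention.to_k.
to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
to_k.requires_grad_(False)

# Trainable rank-r pair, initialised like the patch: normal down-projection, zero up-projection.
lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
lora_k_up = nn.Linear(r, inner_dim, bias=False)
nn.init.normal_(lora_k_down.weight, std=1 / r**2)
nn.init.zeros_(lora_k_up.weight)
scale = 1.0

encoder_hidden_states = torch.randn(1, 77, cross_attention_dim)

# The LoRA branch adds a rank-r residual on top of the frozen projection.
key = to_k(encoder_hidden_states) + lora_k_up(lora_k_down(encoder_hidden_states)) * scale

# Because lora_k_up starts at zero, the adapted key equals the base key before training.
assert torch.allclose(key, to_k(encoder_hidden_states))

Zero-initialising the up-projection is what keeps the adapted attention identical to the base model at step 0; only the down/up pair receives gradients during training.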