From 1d5038280d44a36351cb3aa21aad7a8eff220c94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 20 Dec 2022 22:07:06 +0100 Subject: Fix training --- training/lora.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 training/lora.py (limited to 'training') diff --git a/training/lora.py b/training/lora.py new file mode 100644 index 0000000..d8dc147 --- /dev/null +++ b/training/lora.py @@ -0,0 +1,27 @@ +import torch.nn as nn +from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config + + +class LoraAttention(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, in_features, out_features, r=4): + super().__init__() + + if r > min(in_features, out_features): + raise ValueError( + f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + ) + + self.lora_down = nn.Linear(in_features, r, bias=False) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = 1.0 + + self.processor = XFormersCrossAttnProcessor() + + nn.init.normal_(self.lora_down.weight, std=1 / r**2) + nn.init.zeros_(self.lora_up.weight) + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number) + hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale + return hidden_states -- cgit v1.2.3-70-g09d2