From 403f525d0c6900cc6844c0d2f4ecb385fc131969 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 09:40:24 +0100
Subject: Fixed reproducibility, more consistent validation

---
 training/lora.py | 92 +++++++++++++++++++++++++++++++++++++++-----------------
 training/lr.py   |  6 ++--
 2 files changed, 68 insertions(+), 30 deletions(-)

(limited to 'training')

diff --git a/training/lora.py b/training/lora.py
index e1c0971..3857d78 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 from diffusers import ModelMixin, ConfigMixin
@@ -13,56 +14,93 @@ else:
     xformers = None
 
 
-class LoraAttnProcessor(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        cross_attention_dim,
-        inner_dim,
-        r: int = 4
-    ):
+class LoRALinearLayer(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
         super().__init__()
 
-        if r > min(cross_attention_dim, inner_dim):
+        if rank > min(in_features, out_features):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
+                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
             )
 
-        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+        self.lora_down = nn.Linear(in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, out_features, bias=False)
+        self.scale = 1.0
 
-        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1 / rank)
+        nn.init.zeros_(self.lora_up.weight)
 
-        self.scale = 1.0
+    def forward(self, hidden_states):
+        down_hidden_states = self.lora_down(hidden_states)
+        up_hidden_states = self.lora_up(down_hidden_states)
 
-        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_k_up.weight)
+        return up_hidden_states
 
-        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_v_up.weight)
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+class LoRACrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+        super().__init__()
 
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
-        query = attn.to_q(hidden_states)
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
 
-        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
-        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class LoRAXFormersCrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+        super().__init__()
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
         hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
diff --git a/training/lr.py b/training/lr.py
index 37588b6..a3144ba 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,6 @@
 import math
 import copy
-from typing import Callable
+from typing import Callable, Any, Tuple, Union
 from functools import partial
 
 import matplotlib.pyplot as plt
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn,
+        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -108,7 +108,7 @@ class LRFinder():
                 if step >= num_val_batches:
                     break
 
-                loss, acc, bsz = self.loss_fn(batch)
+                loss, acc, bsz = self.loss_fn(batch, True)
 
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
--
cgit v1.2.3-70-g09d2
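
For context: the LoRACrossAttnProcessor and LoRAXFormersCrossAttnProcessor classes added in training/lora.py follow the diffusers cross-attention processor calling convention, so they are attached per attention block rather than wrapping the whole model. The sketch below shows one plausible way to wire them into a UNet2DConditionModel; it is not taken from this repository, it assumes a diffusers version whose UNet exposes attn_processors / set_attn_processor, and the model id, rank, and optimizer settings are placeholders.

# Hypothetical wiring sketch (not part of the patch): attach the new LoRA processors
# to a diffusers UNet2DConditionModel. Assumes `unet.attn_processors` and
# `unet.set_attn_processor` are available; model id and hyperparameters are placeholders.
import itertools

import torch
from diffusers import UNet2DConditionModel

from training.lora import LoRACrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.requires_grad_(False)  # base weights stay frozen; only the LoRA layers will train

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no encoder states), attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)

# The freshly created LoRA weights are the only trainable parameters.
lora_params = itertools.chain(*(p.parameters() for p in lora_attn_procs.values()))
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)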
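
On the training/lr.py side, LRFinder now invokes its loss callback as loss_fn(batch, True) during the validation pass, while the training pass presumably keeps calling loss_fn(batch); that is what the widened Union[...] annotation describes. A compatible callback might look like the following sketch; the flag name and the body are placeholders, only the call signature and the (loss, accuracy, batch size) return tuple are implied by the patch.

# Hypothetical loss callback matching the new LRFinder call sites: loss_fn(batch) while
# training, loss_fn(batch, True) during validation. Everything inside the body is a
# placeholder; only the signature and the (loss, acc, bsz) return shape come from the patch.
import torch


def loss_fn(batch, eval: bool = False):
    bsz = batch["pixel_values"].shape[0]  # placeholder batch layout

    # When `eval` is True, a fixed or seeded noise/timestep sample could be used so the
    # validation loss stays comparable across steps (the "more consistent validation"
    # this commit aims for).
    loss = torch.zeros((), requires_grad=True)  # placeholder
    acc = torch.zeros(())                       # placeholder

    return loss, acc, bsz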