From daba86ebbfa821f3c3227bcfbcbd532051e793e7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 12 May 2023 18:06:13 +0200 Subject: Fix for latest PEFT --- models/lora.py | 145 ---------------------------------------- train_lora.py | 14 +--- training/attention_processor.py | 47 +++++++++++++ 3 files changed, 50 insertions(+), 156 deletions(-) delete mode 100644 models/lora.py create mode 100644 training/attention_processor.py diff --git a/models/lora.py b/models/lora.py deleted file mode 100644 index e506cff..0000000 --- a/models/lora.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import Optional -import math - -import torch -import torch.nn as nn - - -class LoraLayer(): - def __init__( - self, - r: int, - lora_alpha: int, - lora_dropout: float, - merge_weights: bool, - ): - self.r = r - self.lora_alpha = lora_alpha - self.lora_dropout_p = lora_dropout - - if lora_dropout > 0.: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = nn.Identity() - - self.merged = False - self.merge_weights = merge_weights - - -class LoraEmbedding(nn.Embedding, LoraLayer): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - merge_weights: bool = True, - **kwargs - ): - nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) - LoraLayer.__init__( - self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights - ) - - self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) - - self.lora_A = nn.ParameterList() - self.lora_B = nn.Linear(r, embedding_dim, bias=False) - self.scaling = self.lora_alpha / self.r - self.weight.requires_grad = False - - self.reset_parameters() - - def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): - n = min(self.num_embeddings, new_num_embeddings) - - new_emb = LoraEmbedding( - new_num_embeddings, - self.embedding_dim, - self.r, - self.lora_alpha, - self.lora_dropout_p, - device=self.weight.device, - dtype=self.weight.dtype - ) - if initializer_factor is not None: - new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) - else: - nn.init.zeros_(new_emb.weight.data) - new_emb.weight.data[:n, :] = self.weight.data[:n, :] - for param in self.lora_A: - new_emb.lora_A.append(param) - new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data - new_emb.trainable_ids[:n] = self.trainable_ids[:n] - - return new_emb - - def mark_trainable(self, input_ids: torch.LongTensor): - trainable_ids = self.trainable_ids[input_ids] - new_ids = input_ids[trainable_ids == -1] - - if new_ids.shape[0] == 0: - return - - n1 = len(self.lora_A) - n2 = n1 + new_ids.shape[0] - self.trainable_ids[new_ids] = torch.arange(n1, n2) - for _ in new_ids: - w = self.weight.new_zeros(self.r) - self.lora_A.append(w) - - if len(self.lora_A) > 1: - elems = torch.stack([param for param in self.lora_A]) - nn.init.kaiming_uniform_(elems, a=math.sqrt(5)) - - def get_weights(self, input_ids: torch.Tensor): - if len(input_ids.shape) != 1: - return torch.stack([self.get_weights(batch) for batch in input_ids]) - - weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) - - if not self.merged: - trainable_ids = self.trainable_ids[input_ids] - mask = ~(trainable_ids == -1) - elems = [self.lora_A[id] for id in trainable_ids[mask]] - - if len(elems) != 0: - w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling - weights[mask] = w.to(dtype=weights.dtype) - - return weights - - def persist(self): - self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.trainable_ids[:] = -1 - self.lora_A = nn.ParameterList() - nn.init.zeros_(self.lora_B.weight) - - def reset_parameters(self): - nn.Embedding.reset_parameters(self) - if hasattr(self, "lora_A"): - self.trainable_ids[:] = -1 - self.lora_A = nn.ParameterList() - nn.init.zeros_(self.lora_B.weight) - - def train(self, mode: bool = True): - nn.Embedding.train(self, mode) - self.lora_A.train(mode) - self.lora_B.train(mode) - if not mode and self.merge_weights and not self.merged: - self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.merged = True - elif self.merge_weights and self.merged: - self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.merged = False - - def eval(self): - nn.Embedding.eval(self) - self.lora_A.eval() - self.lora_B.eval() - - def forward(self, input_ids: torch.LongTensor): - result = nn.Embedding.forward(self, input_ids) - result += self.get_weights(input_ids) - return result diff --git a/train_lora.py b/train_lora.py index 737af58..dea58cf 100644 --- a/train_lora.py +++ b/train_lora.py @@ -15,7 +15,7 @@ import hidet from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from peft import LoraConfig, LoraModel +from peft import LoraConfig, get_peft_model # from diffusers.models.attention_processor import AttnProcessor import transformers @@ -731,7 +731,7 @@ def main(): lora_dropout=args.lora_dropout, bias=args.lora_bias, ) - unet = LoraModel(unet_config, unet) + unet = get_peft_model(unet, unet_config) text_encoder_config = LoraConfig( r=args.lora_text_encoder_r, @@ -740,7 +740,7 @@ def main(): lora_dropout=args.lora_text_encoder_dropout, bias=args.lora_text_encoder_bias, ) - text_encoder = LoraModel(text_encoder_config, text_encoder) + text_encoder = get_peft_model(text_encoder, text_encoder_config) vae.enable_slicing() @@ -1167,14 +1167,6 @@ def main(): group_labels.append("unet") if training_iter < args.train_text_encoder_cycles: - # if len(placeholder_tokens) != 0: - # params_to_optimize.append({ - # "params": text_encoder.text_model.embeddings.token_embedding.parameters(), - # "lr": learning_rate_emb, - # "weight_decay": 0, - # }) - # group_labels.append("emb") - params_to_optimize.append({ "params": ( param diff --git a/training/attention_processor.py b/training/attention_processor.py new file mode 100644 index 0000000..4309bd4 --- /dev/null +++ b/training/attention_processor.py @@ -0,0 +1,47 @@ +from typing import Callable, Optional, Union + +import xformers +import xformers.ops + +from diffusers.models.attention_processor import Attention + + +class XFormersAttnProcessor: + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + query = query.to(key.dtype) + value = value.to(key.dtype) + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states -- cgit v1.2.3-70-g09d2