# Provenance: training/attention_processor.py as added by commit
# daba86ebbfa821f3c3227bcfbcbd532051e793e7 ("Fix for latest PEFT",
# Volpeon, Fri, 12 May 2023 18:06:13 +0200).
from typing import Callable, Optional, Union

import xformers
import xformers.ops

from diffusers.models.attention_processor import Attention


class XFormersAttnProcessor:
    """Attention processor that runs attention through xformers.

    The processor projects the inputs with the ``to_q`` / ``to_k`` / ``to_v``
    layers of the given ``Attention`` module, casts query and value to the
    dtype of the key, and evaluates attention with
    ``xformers.ops.memory_efficient_attention``.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        """Store the optional operator forwarded as ``op=`` to xformers."""
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply attention to ``hidden_states`` and return the projected result.

        Args:
            attn: Attention module supplying the projection layers, the mask
                preparation helper, the head reshaping helpers and ``scale``.
            hidden_states: Input used to build the query. Must expose a
                three-element ``shape`` when ``encoder_hidden_states`` is
                ``None``.
            encoder_hidden_states: Optional source for key and value. When
                ``None``, ``hidden_states`` is used instead.
            attention_mask: Optional mask handed to
                ``attn.prepare_attention_mask`` before being used as the
                xformers ``attn_bias``.

        Returns:
            The attention output after ``attn.to_out[0]`` and
            ``attn.to_out[1]`` have been applied.
        """
        self_attention = encoder_hidden_states is None

        # Key/value source: the input itself unless a separate one was given.
        context = hidden_states if self_attention else encoder_hidden_states

        # assumes a (batch, sequence, features) layout -- TODO confirm
        batch_size, sequence_length, _ = context.shape
        bias = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        q = attn.to_q(hidden_states)

        # Normalisation is only considered for an explicitly supplied context.
        if not self_attention and attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        k = attn.to_k(context)
        v = attn.to_v(context)

        q, k, v = (attn.head_to_batch_dim(tensor).contiguous() for tensor in (q, k, v))

        # Align query and value with the key's dtype before calling xformers.
        kernel_dtype = k.dtype
        q = q.to(kernel_dtype)
        v = v.to(kernel_dtype)

        attended = xformers.ops.memory_efficient_attention(
            q,
            k,
            v,
            attn_bias=bias,
            op=self.attention_op,
            scale=attn.scale,
        )
        # NOTE(review): q was already cast to the key dtype above, so this
        # cast targets that same dtype -- confirm this is the intended dtype.
        attended = attended.to(q.dtype)
        attended = attn.batch_to_head_dim(attended)

        # linear proj
        projected = attn.to_out[0](attended)
        # dropout
        result = attn.to_out[1](projected)
        return result