summary refs log tree commit diff stats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r-- training/attention_processor.py 47
1 files changed, 47 insertions, 0 deletions
diff --git a/training/attention_processor.py b/training/attention_processor.py
new file mode 100644
index 0000000..4309bd4
--- /dev/null
+++ b/training/attention_processor.py
@@ -0,0 +1,47 @@
1from typing import Callable, Optional, Union
2
3import xformers
4import xformers.ops
5
6from diffusers.models.attention_processor import Attention
7
8
class XFormersAttnProcessor:
    """Attention processor that delegates the attention computation to xformers.

    An instance is called with an ``Attention`` module and uses that module's
    projections (``to_q``/``to_k``/``to_v``/``to_out``) and reshaping helpers
    around a single ``xformers.ops.memory_efficient_attention`` call.

    Args:
        attention_op: Optional operator forwarded unchanged as the ``op``
            argument of ``xformers.ops.memory_efficient_attention``. ``None``
            is passed through as-is.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply attention to ``hidden_states`` using the layers owned by ``attn``.

        Args:
            attn: Module supplying the projections, the optional encoder norm,
                the mask preparation, the head reshaping helpers and ``scale``.
            hidden_states: Input the queries are projected from. Its shape is
                unpacked into three values when no encoder states are given.
            encoder_hidden_states: Optional source for keys and values. When
                ``None``, ``hidden_states`` is used instead (self-attention).
            attention_mask: Optional mask; it is passed through
                ``attn.prepare_attention_mask`` and then used as ``attn_bias``.

        Returns:
            The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
        """
        # Mask dimensions follow the key/value source: the encoder states when
        # they are provided, otherwise the hidden states themselves.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # Align query and value with the key's dtype before calling the kernel.
        # NOTE(review): presumably the xformers op requires matching dtypes - confirm.
        query = query.to(key.dtype)
        value = value.to(key.dtype)

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        # query was cast above, so this brings the output to the key's dtype.
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states