author    | Volpeon <git@volpeon.ink> | 2023-05-12 18:06:13 +0200
committer | Volpeon <git@volpeon.ink> | 2023-05-12 18:06:13 +0200
commit    | daba86ebbfa821f3c3227bcfbcbd532051e793e7 (patch)
tree      | 37b00a8216b6298004b75bd06ffc5c8ff76bce8e /training
parent    | Update (diff)
Fix for latest PEFT
Diffstat (limited to 'training')
-rw-r--r-- | training/attention_processor.py | 47
1 file changed, 47 insertions, 0 deletions
diff --git a/training/attention_processor.py b/training/attention_processor.py
new file mode 100644
index 0000000..4309bd4
--- /dev/null
+++ b/training/attention_processor.py
@@ -0,0 +1,47 @@
from typing import Callable, Optional, Union

import xformers
import xformers.ops

from diffusers.models.attention_processor import Attention


class XFormersAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        query = query.to(key.dtype)
        value = value.to(key.dtype)

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
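
The processor only takes effect once it is registered on a model's attention layers. Below is a minimal sketch of how it might be hooked into a diffusers UNet; set_attn_processor is part of the diffusers UNet2DConditionModel API, while the model ID, dtype, and device are illustrative assumptions rather than values taken from this repository.

import torch
from diffusers import UNet2DConditionModel

from training.attention_processor import XFormersAttnProcessor

# Load a UNet (model ID, dtype, and device are placeholder assumptions).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Route every self- and cross-attention block in the UNet through the
# xformers-based processor defined in this commit.
unet.set_attn_processor(XFormersAttnProcessor())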