diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/attention_processor.py | 47 | ||||
-rw-r--r-- | training/functional.py | 4 |
2 files changed, 2 insertions, 49 deletions
diff --git a/training/attention_processor.py b/training/attention_processor.py deleted file mode 100644 index 4309bd4..0000000 --- a/training/attention_processor.py +++ /dev/null | |||
@@ -1,47 +0,0 @@ | |||
1 | from typing import Callable, Optional, Union | ||
2 | |||
3 | import xformers | ||
4 | import xformers.ops | ||
5 | |||
6 | from diffusers.models.attention_processor import Attention | ||
7 | |||
8 | |||
class XFormersAttnProcessor:
    """Attention processor that computes attention via xformers.

    The projections, head reshaping and output layers all come from the
    ``Attention`` module passed to ``__call__``; this class only stores the
    optional xformers operator to use.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        # Forwarded unchanged as the ``op`` argument of
        # ``xformers.ops.memory_efficient_attention`` on every call.
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run one attention pass and return the projected hidden states.

        Args:
            attn: Module supplying the q/k/v projections, the head reshaping
                helpers, the mask preparation, ``scale`` and ``to_out``.
            hidden_states: Rank-3 input; queries are always projected from it.
            encoder_hidden_states: Optional rank-3 source for keys and values.
                When ``None``, ``hidden_states`` is used instead.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask`` and then given to xformers as
                ``attn_bias``.

        Returns:
            The result of applying ``attn.to_out[0]`` and then
            ``attn.to_out[1]`` to the attention output.
        """
        cross_attention = encoder_hidden_states is not None
        context = encoder_hidden_states if cross_attention else hidden_states

        # The mask is sized from whichever tensor provides keys and values.
        batch_size, sequence_length, _ = context.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        q_proj = attn.to_q(hidden_states)

        # Only a caller-supplied context is normalised, and only if requested.
        if cross_attention and attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        k_proj = attn.to_k(context)
        v_proj = attn.to_v(context)

        q_heads = attn.head_to_batch_dim(q_proj).contiguous()
        k_heads = attn.head_to_batch_dim(k_proj).contiguous()
        v_heads = attn.head_to_batch_dim(v_proj).contiguous()

        # Bring query and value to the key's dtype so that all three operands
        # handed to xformers share one dtype.
        shared_dtype = k_heads.dtype
        q_heads = q_heads.to(shared_dtype)
        v_heads = v_heads.to(shared_dtype)

        attended = xformers.ops.memory_efficient_attention(
            q_heads,
            k_heads,
            v_heads,
            attn_bias=attention_mask,
            op=self.attention_op,
            scale=attn.scale,
        )
        attended = attn.batch_to_head_dim(attended.to(q_heads.dtype))

        projected = attn.to_out[0](attended)  # linear proj
        return attn.to_out[1](projected)  # dropout
diff --git a/training/functional.py b/training/functional.py index fd3f9f4..10560e5 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -710,8 +710,8 @@ def train( | |||
710 | vae = torch.compile(vae, backend='hidet') | 710 | vae = torch.compile(vae, backend='hidet') |
711 | 711 | ||
712 | if compile_unet: | 712 | if compile_unet: |
713 | unet = torch.compile(unet, backend='hidet') | 713 | # unet = torch.compile(unet, backend='hidet') |
714 | # unet = torch.compile(unet, mode="reduce-overhead") | 714 | unet = torch.compile(unet, mode="reduce-overhead") |
715 | 715 | ||
716 | callbacks = strategy.callbacks( | 716 | callbacks = strategy.callbacks( |
717 | accelerator=accelerator, | 717 | accelerator=accelerator, |