summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/attention_processor.py47
-rw-r--r--training/functional.py4
2 files changed, 2 insertions, 49 deletions
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
1from typing import Callable, Optional, Union
2
3import xformers
4import xformers.ops
5
6from diffusers.models.attention_processor import Attention
7
8
9class XFormersAttnProcessor:
10 def __init__(self, attention_op: Optional[Callable] = None):
11 self.attention_op = attention_op
12
13 def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
14 batch_size, sequence_length, _ = (
15 hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
16 )
17
18 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
19
20 query = attn.to_q(hidden_states)
21
22 if encoder_hidden_states is None:
23 encoder_hidden_states = hidden_states
24 elif attn.norm_cross:
25 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
26
27 key = attn.to_k(encoder_hidden_states)
28 value = attn.to_v(encoder_hidden_states)
29
30 query = attn.head_to_batch_dim(query).contiguous()
31 key = attn.head_to_batch_dim(key).contiguous()
32 value = attn.head_to_batch_dim(value).contiguous()
33
34 query = query.to(key.dtype)
35 value = value.to(key.dtype)
36
37 hidden_states = xformers.ops.memory_efficient_attention(
38 query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
39 )
40 hidden_states = hidden_states.to(query.dtype)
41 hidden_states = attn.batch_to_head_dim(hidden_states)
42
43 # linear proj
44 hidden_states = attn.to_out[0](hidden_states)
45 # dropout
46 hidden_states = attn.to_out[1](hidden_states)
47 return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
710 vae = torch.compile(vae, backend='hidet') 710 vae = torch.compile(vae, backend='hidet')
711 711
712 if compile_unet: 712 if compile_unet:
713 unet = torch.compile(unet, backend='hidet') 713 # unet = torch.compile(unet, backend='hidet')
714 # unet = torch.compile(unet, mode="reduce-overhead") 714 unet = torch.compile(unet, mode="reduce-overhead")
715 715
716 callbacks = strategy.callbacks( 716 callbacks = strategy.callbacks(
717 accelerator=accelerator, 717 accelerator=accelerator,