From 1aace3e44dae0489130039714f67d980628c92ec Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 12:59:08 +0200
Subject: Avoid model recompilation due to varying prompt lengths

---
 training/attention_processor.py | 47 -----------------------------------------
 training/functional.py          |  4 ++--
 2 files changed, 2 insertions(+), 49 deletions(-)
 delete mode 100644 training/attention_processor.py

(limited to 'training')

diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Optional, Union
-
-import xformers
-import xformers.ops
-
-from diffusers.models.attention_processor import Attention
-
-
-class XFormersAttnProcessor:
-    def __init__(self, attention_op: Optional[Callable] = None):
-        self.attention_op = attention_op
-
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        query = query.to(key.dtype)
-        value = value.to(key.dtype)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet, mode="reduce-overhead")
+        # unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
-- 
cgit v1.2.3-70-g09d2
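
For context outside the diff, here is a minimal standalone sketch of the torch.compile call the patch settles on. The model class and checkpoint name below are illustrative assumptions, not taken from this repository; only the final compile line mirrors the patched code.

import torch
from diffusers import UNet2DConditionModel

# Illustrative UNet load; any diffusers UNet with cross-attention behaves the same here.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example checkpoint, not from this repo
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")

# The variant the patch keeps: the default Inductor backend with
# mode="reduce-overhead" instead of backend='hidet'.
# torch.compile specializes the compiled graph on input shapes, so if the
# encoder_hidden_states fed to the UNet change sequence length from prompt
# to prompt, the model gets recompiled; keeping prompt embeddings at a fixed
# length (e.g. the tokenizer's max length) avoids those recompilations.
unet = torch.compile(unet, mode="reduce-overhead")

"reduce-overhead" relies on CUDA graphs under the hood, which are particularly sensitive to changing shapes, so holding the prompt length fixed matters even more with this mode than with a plain Inductor or eager setup.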