| author | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
| commit | 1aace3e44dae0489130039714f67d980628c92ec (patch) | |
| tree | 59a972b64bb3a3253e310055fc24381db68e8608 /training | |
| parent | Patch xformers to cast dtypes (diff) | |
Avoid model recompilation due to varying prompt lengths
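Background for this change, as a hedged sketch: `torch.compile` specializes the compiled graph on input shapes, so prompts that tokenize to different lengths can force the compiled UNet to be retraced for every new sequence length. The snippet below is a minimal illustration of that behaviour, not code from this repository; the toy linear model, the shapes, and the padding suggestion in the trailing comment are assumptions for demonstration only.

```python
# Minimal sketch (not from this repository): with shape specialization
# enabled, every new input shape forces a fresh compilation.
import torch

model = torch.compile(torch.nn.Linear(768, 768), dynamic=False)

# Each distinct sequence length here may trigger another compile pass,
# which is the recompilation overhead this commit works around.
for seq_len in (7, 12, 33):
    model(torch.randn(1, seq_len, 768))

# A common mitigation is to pad every prompt to a fixed token length so the
# compiled graph always sees the same shape, e.g.
# tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length)
```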
Diffstat (limited to 'training')
| -rw-r--r-- | training/attention_processor.py | 47 |
| -rw-r--r-- | training/functional.py | 4 |
2 files changed, 2 insertions, 49 deletions
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Optional, Union
-
-import xformers
-import xformers.ops
-
-from diffusers.models.attention_processor import Attention
-
-
-class XFormersAttnProcessor:
-    def __init__(self, attention_op: Optional[Callable] = None):
-        self.attention_op = attention_op
-
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        query = query.to(key.dtype)
-        value = value.to(key.dtype)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet, mode="reduce-overhead")
+        # unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
