|  |  |  |
| --- | --- | --- |
| author | Volpeon <git@volpeon.ink> | 2022-12-21 16:48:25 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2022-12-21 16:48:25 +0100 |
| commit | 47b2fba05ba0f2f5335321d94e0634a7291980c5 (patch) | |
| tree | cceed035f7ff5eec2a067b9865ea3ac278634275 /training | |
| parent | Moved common training code into separate module (diff) | |
| download | textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.tar.gz textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.tar.bz2 textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.zip | |
Some LoRA fixes (still broken)
Diffstat (limited to 'training')
-rw-r--r-- | training/lora.py | 68 |
1 file changed, 55 insertions, 13 deletions
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
 
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
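For context, the low-rank update that `LoraAttnProcessor` adds to the key/value projections can be sketched in isolation. This is a minimal illustration, not part of the commit: the dimensions (`cross_attention_dim=768`, `inner_dim=320`) and the standalone `to_k` layer are hypothetical stand-ins for a UNet cross-attention block, and only plain PyTorch is used.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions standing in for a UNet cross-attention block;
# not taken from the commit.
cross_attention_dim, inner_dim, r = 768, 320, 4

# Frozen base projection, playing the role of attn.to_k / attn.to_v.
to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)

# LoRA pair as in LoraAttnProcessor: down-project to rank r, up-project back.
lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
lora_k_up = nn.Linear(r, inner_dim, bias=False)
nn.init.normal_(lora_k_down.weight, std=1 / r**2)
nn.init.zeros_(lora_k_up.weight)  # zero init: the LoRA branch contributes nothing at step 0

scale = 1.0
encoder_hidden_states = torch.randn(2, 77, cross_attention_dim)

# key = frozen projection + scaled low-rank correction, mirroring the processor's __call__.
key = to_k(encoder_hidden_states) + lora_k_up(lora_k_down(encoder_hidden_states)) * scale

# Because lora_k_up starts at zero, the initial result equals the frozen projection alone.
assert torch.allclose(key, to_k(encoder_hidden_states))
```

Only the `lora_*_down` / `lora_*_up` pairs are meant to be trained, roughly `r * (cross_attention_dim + inner_dim)` parameters per projection, while the base `to_k` / `to_v` weights stay frozen.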