path: root/training/lora.py
author    Volpeon <git@volpeon.ink>  2022-12-21 16:48:25 +0100
committer Volpeon <git@volpeon.ink>  2022-12-21 16:48:25 +0100
commit    47b2fba05ba0f2f5335321d94e0634a7291980c5 (patch)
tree      cceed035f7ff5eec2a067b9865ea3ac278634275 /training/lora.py
parent    Moved common training code into separate module (diff)
download  textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.tar.gz
          textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.tar.bz2
          textual-inversion-diff-47b2fba05ba0f2f5335321d94e0634a7291980c5.zip
Some LoRA fixes (still broken)
Diffstat (limited to 'training/lora.py')
-rw-r--r--  training/lora.py  68
1 file changed, 55 insertions(+), 13 deletions(-)
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
 
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
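
For reference, the low-rank update this processor applies to the key and value projections boils down to base(x) + up(down(x)) * scale, with the up-projection zero-initialized so the patched attention initially reproduces the original output. A minimal self-contained sketch of that pattern in plain PyTorch follows; the LoraLinear wrapper and the 768/320 example dimensions are hypothetical illustrations, not part of the commit.

# Standalone sketch of the low-rank update used in the processor above.
# LoraLinear and the example dimensions are hypothetical, not part of the commit.
import torch
import torch.nn as nn


class LoraLinear(nn.Module):
    """Wraps a frozen nn.Linear and adds a rank-r correction: base(x) + up(down(x)) * scale."""

    def __init__(self, base: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, r, bias=False)
        self.up = nn.Linear(r, base.out_features, bias=False)
        self.scale = scale

        # Same init scheme as the commit: small random down-projection and a
        # zero up-projection, so the LoRA path contributes nothing at step 0.
        nn.init.normal_(self.down.weight, std=1 / r**2)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.up(self.down(x)) * self.scale


# Example: wrap a frozen to_k-style projection and train only the LoRA weights.
to_k = nn.Linear(768, 320, bias=False).requires_grad_(False)
lora_k = LoraLinear(to_k, r=4)
x = torch.randn(1, 77, 768)
assert torch.allclose(lora_k(x), to_k(x))  # identical at init: the up-projection is zero

Zero-initializing the up-projection is what makes the patch safe to bolt onto a pretrained model: the wrapped module starts out numerically identical to the frozen one, and only the small down/up matrices receive gradients. How the processor is attached to each cross-attention block of the UNet is outside the scope of this diff.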