summaryrefslogtreecommitdiffstats
path: root/training/lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
commit403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch)
tree385c62ef44cc33abc3c5d4b2084c376551137c5f /training/lora.py
parentDon't use vector_dropout by default (diff)
downloadtextual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip
Fixed reproducibility, more consistant validation
Diffstat (limited to 'training/lora.py')
-rw-r--r--training/lora.py92
1 files changed, 65 insertions, 27 deletions
diff --git a/training/lora.py b/training/lora.py
index e1c0971..3857d78 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,3 +1,4 @@
1import torch
1import torch.nn as nn 2import torch.nn as nn
2 3
3from diffusers import ModelMixin, ConfigMixin 4from diffusers import ModelMixin, ConfigMixin
@@ -13,56 +14,93 @@ else:
13 xformers = None 14 xformers = None
14 15
15 16
16class LoraAttnProcessor(ModelMixin, ConfigMixin): 17class LoRALinearLayer(nn.Module):
17 @register_to_config 18 def __init__(self, in_features, out_features, rank=4):
18 def __init__(
19 self,
20 cross_attention_dim,
21 inner_dim,
22 r: int = 4
23 ):
24 super().__init__() 19 super().__init__()
25 20
26 if r > min(cross_attention_dim, inner_dim): 21 if rank > min(in_features, out_features):
27 raise ValueError( 22 raise ValueError(
28 f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}" 23 f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
29 ) 24 )
30 25
31 self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False) 26 self.lora_down = nn.Linear(in_features, rank, bias=False)
32 self.lora_k_up = nn.Linear(r, inner_dim, bias=False) 27 self.lora_up = nn.Linear(rank, out_features, bias=False)
28 self.scale = 1.0
33 29
34 self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False) 30 nn.init.normal_(self.lora_down.weight, std=1 / rank)
35 self.lora_v_up = nn.Linear(r, inner_dim, bias=False) 31 nn.init.zeros_(self.lora_up.weight)
36 32
37 self.scale = 1.0 33 def forward(self, hidden_states):
34 down_hidden_states = self.lora_down(hidden_states)
35 up_hidden_states = self.lora_up(down_hidden_states)
38 36
39 nn.init.normal_(self.lora_k_down.weight, std=1 / r**2) 37 return up_hidden_states
40 nn.init.zeros_(self.lora_k_up.weight)
41 38
42 nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
43 nn.init.zeros_(self.lora_v_up.weight)
44 39
45 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 40class LoRACrossAttnProcessor(nn.Module):
46 batch_size, sequence_length, _ = hidden_states.shape 41 def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
42 super().__init__()
47 43
44 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
45 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
46 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
47 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
48
49 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
50 batch_size, sequence_length, _ = hidden_states.shape
48 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 51 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
49 52
50 query = attn.to_q(hidden_states) 53 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
54 query = attn.head_to_batch_dim(query)
51 55
52 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 56 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
53 key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
54 value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
55 57
58 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
59 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
60
61 key = attn.head_to_batch_dim(key)
62 value = attn.head_to_batch_dim(value)
63
64 attention_probs = attn.get_attention_scores(query, key, attention_mask)
65 hidden_states = torch.bmm(attention_probs, value)
66 hidden_states = attn.batch_to_head_dim(hidden_states)
67
68 # linear proj
69 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
70 # dropout
71 hidden_states = attn.to_out[1](hidden_states)
72
73 return hidden_states
74
75
76class LoRAXFormersCrossAttnProcessor(nn.Module):
77 def __init__(self, hidden_size, cross_attention_dim, rank=4):
78 super().__init__()
79
80 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
81 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
82 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
83 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
84
85 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
86 batch_size, sequence_length, _ = hidden_states.shape
87 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
88
89 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
56 query = attn.head_to_batch_dim(query).contiguous() 90 query = attn.head_to_batch_dim(query).contiguous()
91
92 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
93
94 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
95 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
96
57 key = attn.head_to_batch_dim(key).contiguous() 97 key = attn.head_to_batch_dim(key).contiguous()
58 value = attn.head_to_batch_dim(value).contiguous() 98 value = attn.head_to_batch_dim(value).contiguous()
59 99
60 hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 100 hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
61 hidden_states = hidden_states.to(query.dtype)
62 hidden_states = attn.batch_to_head_dim(hidden_states)
63 101
64 # linear proj 102 # linear proj
65 hidden_states = attn.to_out[0](hidden_states) 103 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
66 # dropout 104 # dropout
67 hidden_states = attn.to_out[1](hidden_states) 105 hidden_states = attn.to_out[1](hidden_states)
68 106