Diffstat (limited to 'training/lora.py')
-rw-r--r--  training/lora.py  107
1 file changed, 0 insertions, 107 deletions
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
import torch
import torch.nn as nn

from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available


# xformers is optional; only LoRAXFormersCrossAttnProcessor requires it.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}"
            )

        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        # Zero-init the up projection so the LoRA branch starts as a no-op and the
        # wrapped model initially reproduces the base model's output.
        nn.init.normal_(self.lora_down.weight, std=1 / rank)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, hidden_states):
        down_hidden_states = self.lora_down(hidden_states)
        up_hidden_states = self.lora_up(down_hidden_states)

        return up_hidden_states


class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        # Low-rank adapters for the query/key/value/output projections of the
        # wrapped CrossAttention module.
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        # Each projection is the frozen base weight plus the scaled LoRA update.
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class LoRAXFormersCrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        # xformers expects contiguous (batch * heads, seq_len, head_dim) tensors.
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        # Fold the heads back into the feature dimension before the output projection.
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
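
For context, the sketch below shows one way processors like these are typically attached to a diffusers UNet, following the pattern of the diffusers LoRA training examples. It is a minimal sketch, assuming a diffusers version (around 0.12) whose UNet2DConditionModel exposes attn_processors and set_attn_processor, and an already-loaded `unet`; the name parsing, the rank value, and the optimizer settings are illustrative assumptions, not taken from this repository.

# Hypothetical wiring sketch (not part of the deleted file): map every attention
# processor slot in the UNet to a LoRACrossAttnProcessor of the matching width.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # "attn1" is self-attention (no encoder context); "attn2" is cross-attention
    # over the text-encoder hidden states.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)

# Only the LoRA parameters are trained; the base UNet weights stay frozen.
lora_params = [p for proc in lora_attn_procs.values() for p in proc.parameters()]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

At inference time, diffusers pipelines typically forward the `scale` argument to these processors via cross_attention_kwargs (for example cross_attention_kwargs={"scale": 0.5}), which blends the LoRA contribution in or out without touching the base weights.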
