|   |   |   |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
| commit | 403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch) | |
| tree | 385c62ef44cc33abc3c5d4b2084c376551137c5f /training | |
| parent | Don't use vector_dropout by default (diff) | |
Fixed reproducibility, more consistent validation
Diffstat (limited to 'training')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | training/lora.py | 92 |
| -rw-r--r-- | training/lr.py | 6 |

2 files changed, 68 insertions(+), 30 deletions(-)
diff --git a/training/lora.py b/training/lora.py
index e1c0971..3857d78 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 from diffusers import ModelMixin, ConfigMixin
@@ -13,56 +14,93 @@ else:
     xformers = None
 
 
-class LoraAttnProcessor(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        cross_attention_dim,
-        inner_dim,
-        r: int = 4
-    ):
+class LoRALinearLayer(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
         super().__init__()
 
-        if r > min(cross_attention_dim, inner_dim):
+        if rank > min(in_features, out_features):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
+                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
             )
 
-        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+        self.lora_down = nn.Linear(in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, out_features, bias=False)
+        self.scale = 1.0
 
-        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1 / rank)
+        nn.init.zeros_(self.lora_up.weight)
 
-        self.scale = 1.0
+    def forward(self, hidden_states):
+        down_hidden_states = self.lora_down(hidden_states)
+        up_hidden_states = self.lora_up(down_hidden_states)
 
-        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_k_up.weight)
+        return up_hidden_states
 
-        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_v_up.weight)
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+class LoRACrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+        super().__init__()
 
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
-        query = attn.to_q(hidden_states)
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
-        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
 
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class LoRAXFormersCrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+        super().__init__()
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
         hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
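The commit only rewrites the processor classes; the code that attaches them to a UNet is not part of this diff. Below is a minimal sketch of the usual wiring, following the attention-processor pattern from the upstream diffusers LoRA examples (`unet.attn_processors`, `unet.set_attn_processor`). The model id, the `training.lora` import path, and the rank are illustrative assumptions, not taken from this commit.

```python
# Sketch: attach the LoRA processors from training/lora.py to a UNet.
# Assumes the diffusers attention-processor API; names are illustrative.
from diffusers import UNet2DConditionModel

from training.lora import LoRACrossAttnProcessor  # assumed import path

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # example model id
)
unet.requires_grad_(False)  # base weights stay frozen; only LoRA layers train

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention and has no encoder states, hence no cross_attention_dim
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

    # the hidden size depends on which UNet block the attention layer sits in
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)

# Only the LoRA parameters go into the optimizer.
lora_params = [p for proc in lora_attn_procs.values() for p in proc.parameters()]
```

Because the up-projections are zero-initialized, registering the processors leaves the UNet's output unchanged at step 0; the `scale` argument then controls how strongly the LoRA contribution is blended in during training and inference.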
diff --git a/training/lr.py b/training/lr.py
index 37588b6..a3144ba 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,6 @@
 import math
 import copy
-from typing import Callable
+from typing import Callable, Any, Tuple, Union
 from functools import partial
 
 import matplotlib.pyplot as plt
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn,
+        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -108,7 +108,7 @@ class LRFinder():
                 if step >= num_val_batches:
                     break
 
-                loss, acc, bsz = self.loss_fn(batch)
+                loss, acc, bsz = self.loss_fn(batch, True)
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
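With this change, `LRFinder` calls `loss_fn(batch, True)` during validation, matching the widened `Union[...]` type hint. Below is a hypothetical loss callable with that `(batch, bool) -> (loss, acc, batch_size)` shape, where the flag seeds the random sampling so every validation pass sees the same randomness, which is what makes validation consistent between runs. The toy classification loss and all names are illustrative, not taken from this repository.

```python
# Sketch of a loss_fn compatible with the updated LRFinder signature:
# Callable[[Any, bool], Tuple[Any, Any, int]]. Illustrative only.
from typing import Any, Tuple

import torch
import torch.nn.functional as F


def make_loss_fn(model: torch.nn.Module):
    def loss_fn(batch: Any, eval_: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
        inputs, targets = batch
        bsz = inputs.shape[0]

        # During validation, draw noise from a fixed seed so repeated eval
        # passes (and reruns of the script) compare like with like.
        generator = torch.Generator(device=inputs.device).manual_seed(0) if eval_ else None
        noise = torch.randn(inputs.shape, generator=generator, device=inputs.device, dtype=inputs.dtype)

        logits = model(inputs + 0.1 * noise)  # toy perturbed forward pass
        loss = F.cross_entropy(logits, targets)
        acc = (logits.argmax(dim=-1) == targets).float().mean()
        return loss, acc, bsz

    return loss_fn
```

Passed as the `loss_fn` argument, such a callable only takes the seeded branch in the validation loop shown in the hunk above, while training batches keep fresh randomness.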
