| author | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
| commit | 1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch) | |
| tree | 55ae4c0a660f5218c072d33f2896024b47c05c6b /training | |
| parent | Dependency cleanup/upgrades (diff) | |
Fix training
Diffstat (limited to 'training')
| -rw-r--r-- | training/lora.py | 27 |
1 file changed, 27 insertions, 0 deletions
diff --git a/training/lora.py b/training/lora.py
new file mode 100644
index 0000000..d8dc147
--- /dev/null
+++ b/training/lora.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
+
+
+class LoraAttention(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(self, in_features, out_features, r=4):
+        super().__init__()
+
+        if r > min(in_features, out_features):
+            raise ValueError(
+                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+            )
+
+        self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.scale = 1.0
+
+        self.processor = XFormersCrossAttnProcessor()
+
+        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_up.weight)
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
+        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
+        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
+        return hidden_states
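
The added `LoraAttention` class wraps the existing xFormers attention processor and applies a LoRA-style low-rank residual: the processor output is projected down to rank `r` by `lora_down`, projected back up by `lora_up`, scaled, and added back in. Because `lora_up` is zero-initialized, the residual starts at zero and training begins from the unmodified base attention. The following is a minimal, self-contained sketch of that update in plain PyTorch; it is not part of the commit, and the tensor shapes and variable names (`in_features`, `out_features`, `rank`, `base_out`) are illustrative assumptions.

```python
# Minimal sketch (not from the commit) of the low-rank residual update
# that LoraAttention adds on top of the base attention processor output.
import torch
import torch.nn as nn

in_features, out_features, rank = 320, 320, 4   # illustrative sizes

lora_down = nn.Linear(in_features, rank, bias=False)  # project down to rank r
lora_up = nn.Linear(rank, out_features, bias=False)   # project back up
scale = 1.0

nn.init.normal_(lora_down.weight, std=1 / rank**2)    # same init as the patch
nn.init.zeros_(lora_up.weight)                        # zero init => residual is zero at start

base_out = torch.randn(2, 77, in_features)            # stand-in for the processor output
out = base_out + lora_up(lora_down(base_out)) * scale

# With lora_up still at zero, the module reproduces the base output exactly,
# so only gradient updates to the two small matrices change the behavior.
assert torch.allclose(out, base_out)
```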
