author	Volpeon <git@volpeon.ink>	2022-12-20 22:07:06 +0100
committer	Volpeon <git@volpeon.ink>	2022-12-20 22:07:06 +0100
commit	1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch)
tree	55ae4c0a660f5218c072d33f2896024b47c05c6b /training
parent	Dependency cleanup/upgrades (diff)
Fix training
Diffstat (limited to 'training')
-rw-r--r--	training/lora.py	27
1 file changed, 27 insertions, 0 deletions
diff --git a/training/lora.py b/training/lora.py
new file mode 100644
index 0000000..d8dc147
--- /dev/null
+++ b/training/lora.py
@@ -0,0 +1,27 @@
import torch.nn as nn
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.cross_attention import XFormersCrossAttnProcessor


class LoraAttention(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_features, out_features, r=4):
        super().__init__()

        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            )

        # Low-rank decomposition: project down to rank r, then back up to the full width.
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = 1.0

        # The actual attention computation is delegated to the memory-efficient xFormers processor.
        self.processor = XFormersCrossAttnProcessor()

        # Zero-initializing the up-projection makes the LoRA branch a no-op at the start of training.
        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_up.weight)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
        # Run the standard xFormers cross-attention first; the processor does not take the extra
        # `number` argument, so it is not forwarded.
        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)
        # Add the scaled low-rank correction on top of the attention output.
        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
        return hidden_states
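
Note (not part of the commit): because lora_up is zero-initialized, the LoRA branch contributes nothing until training updates it, so __call__ initially returns the plain xFormers attention output. A minimal standalone sketch of that residual update, using made-up shapes and names (base_out, features) and only plain PyTorch:

import torch
import torch.nn as nn

# Illustrative stand-ins for the module's config and the attention output shape.
batch, tokens, features, r, scale = 2, 77, 320, 4, 1.0

lora_down = nn.Linear(features, r, bias=False)   # project hidden states down to rank r
lora_up = nn.Linear(r, features, bias=False)     # project back up to the original width
nn.init.normal_(lora_down.weight, std=1 / r**2)  # same init as in LoraAttention above
nn.init.zeros_(lora_up.weight)                   # zero init: the correction starts at zero

base_out = torch.randn(batch, tokens, features)  # stands in for the xFormers processor output
out = base_out + lora_up(lora_down(base_out)) * scale

assert torch.allclose(out, base_out)             # the zero-initialized branch is a no-op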