| author | Volpeon <git@volpeon.ink> | 2023-06-22 07:33:29 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-22 07:33:29 +0200 |
| commit | 186a69104530610f8c2b924f79a04f941e5238c8 (patch) | |
| tree | f04de211c4f33151c5163be222f7297087edb7d4 /models/attention/hook.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-186a69104530610f8c2b924f79a04f941e5238c8.tar.gz textual-inversion-diff-186a69104530610f8c2b924f79a04f941e5238c8.tar.bz2 textual-inversion-diff-186a69104530610f8c2b924f79a04f941e5238c8.zip | |
Remove convnext
Diffstat (limited to 'models/attention/hook.py')
-rw-r--r-- | models/attention/hook.py | 63 |
1 file changed, 0 insertions, 63 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py
deleted file mode 100644
index 6b5fb68..0000000
--- a/models/attention/hook.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-
-
-try:
-    import xformers.ops
-
-    xformers._is_functorch_available = True
-    MEM_EFFICIENT_ATTN = True
-except ImportError:
-    print("[!] Not using xformers memory efficient attention.")
-    MEM_EFFICIENT_ATTN = False
-
-
-def register_attention_control(model, controller):
-    def ca_forward(self, place_in_unet):
-        def forward(x, context=None, mask=None):
-            batch_size, sequence_length, dim = x.shape
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q = self.reshape_heads_to_batch_dim(q)
-            k = self.reshape_heads_to_batch_dim(k)
-            v = self.reshape_heads_to_batch_dim(v)
-
-            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-            if mask is not None:
-                mask = mask.reshape(batch_size, -1)
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            # attention, what we cannot get enough of
-            attn = sim.softmax(dim=-1)
-            attn = controller(attn, is_cross, place_in_unet)
-            out = torch.einsum("b i j, b j d -> b i d", attn, v)
-            out = self.reshape_batch_dim_to_heads(out)
-            return self.to_out(out)
-
-        return forward
-
-    def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == "CrossAttention":
-            net_.forward = ca_forward(net_, place_in_unet)
-            return count + 1
-        elif hasattr(net_, "children"):
-            for net__ in net_.children():
-                count = register_recr(net__, count, place_in_unet)
-            return count
-
-    cross_att_count = 0
-    sub_nets = model.unet.named_children()
-    for net in sub_nets:
-        if "down" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "down")
-        elif "up" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "up")
-        elif "mid" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "mid")
-    controller.num_att_layers = cross_att_count
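For reference, the removed `register_attention_control` wrapped the `forward` of every `CrossAttention` module under `model.unet` and routed the resulting attention maps through a `controller` callable, finally storing the number of patched layers in `controller.num_att_layers`. The sketch below shows how such a hook could be driven; `CountingController` and the `model` object (anything exposing a diffusers-style `.unet`) are illustrative assumptions, not part of this repository.

```python
# Illustrative sketch only: CountingController and `model` are assumptions,
# not part of this repository. `model` stands for any object exposing a
# diffusers-style `.unet`, since the removed code iterated over
# `model.unet.named_children()`.

class CountingController:
    """Pass-through controller: returns attention maps unchanged."""

    def __init__(self):
        self.num_att_layers = 0  # filled in by register_attention_control

    def __call__(self, attn, is_cross, place_in_unet):
        # A real controller would record or edit `attn` here, keyed by
        # whether this is cross-attention and where in the UNet it sits.
        return attn


# controller = CountingController()
# register_attention_control(model, controller)
# print(controller.num_att_layers)  # number of CrossAttention blocks patched
```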