author     Volpeon <git@volpeon.ink>  2023-06-22 07:33:29 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-22 07:33:29 +0200
commit     186a69104530610f8c2b924f79a04f941e5238c8 (patch)
tree       f04de211c4f33151c5163be222f7297087edb7d4
parent     Update (diff)
Remove convnext
-rw-r--r--  models/attention/control.py                          216
-rw-r--r--  models/attention/hook.py                               63
-rw-r--r--  models/attention/structured.py                        145
-rw-r--r--  models/convnext/discriminator.py                       34
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py     1
-rw-r--r--  train_ti.py                                             10
-rw-r--r--  training/functional.py                                  13
7 files changed, 0 insertions, 482 deletions
diff --git a/models/attention/control.py b/models/attention/control.py
deleted file mode 100644
index ec378c4..0000000
--- a/models/attention/control.py
+++ /dev/null
@@ -1,216 +0,0 @@
-import torch
-import abc
-
-
-class AttentionControl(abc.ABC):
-    def step_callback(self, x_t):
-        return x_t
-
-    def between_steps(self):
-        return
-
-    @property
-    def num_uncond_att_layers(self):
-        return self.num_att_layers if LOW_RESOURCE else 0
-
-    @abc.abstractmethod
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        raise NotImplementedError
-
-    def __call__(self, attn, is_cross: bool, place_in_unet: str):
-        if self.cur_att_layer >= self.num_uncond_att_layers:
-            if LOW_RESOURCE:
-                attn = self.forward(attn, is_cross, place_in_unet)
-            else:
-                h = attn.shape[0]
-                attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
-        self.cur_att_layer += 1
-        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
-            self.cur_att_layer = 0
-            self.cur_step += 1
-            self.between_steps()
-        return attn
-
-    def reset(self):
-        self.cur_step = 0
-        self.cur_att_layer = 0
-
-    def __init__(self):
-        self.cur_step = 0
-        self.num_att_layers = -1
-        self.cur_att_layer = 0
-
-
-class EmptyControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-
-class AttentionStore(AttentionControl):
-    @staticmethod
-    def get_empty_store():
-        return {
-            "down_cross": [],
-            "mid_cross": [],
-            "up_cross": [],
-            "down_self": [],
-            "mid_self": [],
-            "up_self": [],
-        }
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
-        if attn.shape[1] <= 32**2:  # avoid memory overhead
-            self.step_store[key].append(attn)
-        return attn
-
-    def between_steps(self):
-        if len(self.attention_store) == 0:
-            self.attention_store = self.step_store
-        else:
-            for key in self.attention_store:
-                for i in range(len(self.attention_store[key])):
-                    self.attention_store[key][i] += self.step_store[key][i]
-        self.step_store = self.get_empty_store()
-
-    def get_average_attention(self):
-        average_attention = {
-            key: [item / self.cur_step for item in self.attention_store[key]]
-            for key in self.attention_store
-        }
-        return average_attention
-
-    def reset(self):
-        super(AttentionStore, self).reset()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-    def __init__(self):
-        super(AttentionStore, self).__init__()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-
-class AttentionControlEdit(AttentionStore, abc.ABC):
-    def step_callback(self, x_t):
-        if self.local_blend is not None:
-            x_t = self.local_blend(x_t, self.attention_store)
-        return x_t
-
-    def replace_self_attention(self, attn_base, att_replace):
-        if att_replace.shape[2] <= 16**2:
-            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
-        else:
-            return att_replace
-
-    @abc.abstractmethod
-    def replace_cross_attention(self, attn_base, att_replace):
-        raise NotImplementedError
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
-        if is_cross or (
-            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
-        ):
-            h = attn.shape[0] // (self.batch_size)
-            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
-            attn_base, attn_repalce = attn[0], attn[1:]
-            if is_cross:
-                alpha_words = self.cross_replace_alpha[self.cur_step]
-                attn_repalce_new = (
-                    self.replace_cross_attention(attn_base, attn_repalce) * alpha_words
-                    + (1 - alpha_words) * attn_repalce
-                )
-                attn[1:] = attn_repalce_new
-            else:
-                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
-            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
-        return attn
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: Union[
-            float, Tuple[float, float], Dict[str, Tuple[float, float]]
-        ],
-        self_replace_steps: Union[float, Tuple[float, float]],
-        local_blend: Optional[LocalBlend],
-    ):
-        super(AttentionControlEdit, self).__init__()
-        self.batch_size = len(prompts)
-        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
-            prompts, num_steps, cross_replace_steps, tokenizer
-        ).to(device)
-        if type(self_replace_steps) is float:
-            self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
-            num_steps * self_replace_steps[1]
-        )
-        self.local_blend = local_blend
-
-
-class AttentionReplace(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
-
-
-class AttentionRefine(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
-        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionRefine, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
-        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
-        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
-
-
-class AttentionReweight(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        if self.prev_controller is not None:
-            attn_base = self.prev_controller.replace_cross_attention(
-                attn_base, att_replace
-            )
-        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        equalizer,
-        local_blend: Optional[LocalBlend] = None,
-        controller: Optional[AttentionControlEdit] = None,
-    ):
-        super(AttentionReweight, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.equalizer = equalizer.to(device)
-        self.prev_controller = controller
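
Note: the classes in the removed control.py were adapted from the prompt-to-prompt codebase and were never importable on their own: they read LOW_RESOURCE, tokenizer, device, ptp_utils, seq_aligner and LocalBlend as module-level globals and used Union/Tuple/Dict/Optional without importing them from typing. A hypothetical sketch of the setup a caller had to perform in the pre-commit tree before constructing the edit controllers (the `pipe` object and its attributes are assumptions):

    # Hypothetical pre-commit setup; `pipe` is assumed to be a loaded diffusers pipeline.
    # The removed module resolved these names as globals at call time.
    import models.attention.control as control

    control.LOW_RESOURCE = False        # False keeps the full-batch path in AttentionControl.__call__
    control.tokenizer = pipe.tokenizer  # used when building the cross-replace alphas and mappers
    control.device = pipe.device        # where the alpha and mapper tensors are moved
    # ptp_utils, seq_aligner and LocalBlend came from the upstream prompt-to-prompt
    # repository and had to be injected the same way.
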
diff --git a/models/attention/hook.py b/models/attention/hook.py
deleted file mode 100644
index 6b5fb68..0000000
--- a/models/attention/hook.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-
-
-try:
-    import xformers.ops
-
-    xformers._is_functorch_available = True
-    MEM_EFFICIENT_ATTN = True
-except ImportError:
-    print("[!] Not using xformers memory efficient attention.")
-    MEM_EFFICIENT_ATTN = False
-
-
-def register_attention_control(model, controller):
-    def ca_forward(self, place_in_unet):
-        def forward(x, context=None, mask=None):
-            batch_size, sequence_length, dim = x.shape
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q = self.reshape_heads_to_batch_dim(q)
-            k = self.reshape_heads_to_batch_dim(k)
-            v = self.reshape_heads_to_batch_dim(v)
-
-            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-            if mask is not None:
-                mask = mask.reshape(batch_size, -1)
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            # attention, what we cannot get enough of
-            attn = sim.softmax(dim=-1)
-            attn = controller(attn, is_cross, place_in_unet)
-            out = torch.einsum("b i j, b j d -> b i d", attn, v)
-            out = self.reshape_batch_dim_to_heads(out)
-            return self.to_out(out)
-
-        return forward
-
-    def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == "CrossAttention":
-            net_.forward = ca_forward(net_, place_in_unet)
-            return count + 1
-        elif hasattr(net_, "children"):
-            for net__ in net_.children():
-                count = register_recr(net__, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    sub_nets = model.unet.named_children()
-    for net in sub_nets:
-        if "down" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "down")
-        elif "up" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "up")
-        elif "mid" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "mid")
-    controller.num_att_layers = cross_att_count
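
Note: register_attention_control patches the forward of every submodule of model.unet whose class is literally named CrossAttention, and its replacement forward keeps the old context=/mask= keyword arguments, so it only works with diffusers versions from that era (newer releases renamed the block to Attention and changed the call signature). A minimal usage sketch for the pre-commit tree, collecting step-averaged attention maps with the AttentionStore defined above; the model id and step count are assumptions:

    import torch
    from diffusers import StableDiffusionPipeline

    import models.attention.control as control
    from models.attention.control import AttentionStore
    from models.attention.hook import register_attention_control

    control.LOW_RESOURCE = False  # global flag the removed control.py expected

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    store = AttentionStore()
    register_attention_control(pipe, store)  # walks pipe.unet and sets store.num_att_layers

    with torch.no_grad():
        pipe("a photo of a cat", num_inference_steps=20)

    # maps of at most 32x32 query locations were accumulated per step and averaged here
    maps = store.get_average_attention()
    print({key: [m.shape for m in value] for key, value in maps.items()})
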
diff --git a/models/attention/structured.py b/models/attention/structured.py
deleted file mode 100644
index 5bbbc06..0000000
--- a/models/attention/structured.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import torch
-
-from .control import AttentionControl
-
-
-class StructuredAttentionControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q = self.to_q(x)
-
-        if isinstance(context, list):
-            if self.struct_attn:
-                out = self.struct_qkv(q, context, mask)
-            else:
-                context = torch.cat(
-                    [context[0], context[1]["k"][0]], dim=0
-                )  # use key tensor for context
-                out = self.normal_qkv(q, context, mask)
-        else:
-            context = default(context, x)
-            out = self.normal_qkv(q, context, mask)
-
-        return self.to_out(out)
-
-    def struct_qkv(self, q, context, mask):
-        """
-        context: list of [uc, list of conditional context]
-        """
-        uc_context = context[0]
-        context_k, context_v = context[1]["k"], context[1]["v"]
-
-        if isinstance(context_k, list) and isinstance(context_v, list):
-            out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
-        elif isinstance(context_k, torch.Tensor) and isinstance(
-            context_v, torch.Tensor
-        ):
-            out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
-        else:
-            raise NotImplementedError
-
-        return out
-
-    def multi_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-
-        assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
-        true_bs = uc_context.size(0) * h
-
-        k_uc, v_uc = self.get_kv(uc_context)
-        k_c = [self.to_k(c_k) for c_k in context_k]
-        v_c = [self.to_v(c_v) for c_v in context_v]
-
-        q = rearrange(q, "b n (h d) -> (b h) n d", h=h)
-
-        k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h)
-        v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h)
-
-        k_c = [
-            rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c
-        ]  # NOTE: modification point
-        v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c]
-
-        # get composition
-        sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale
-        sim_c = [
-            einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c
-        ]
-
-        attn_uc = sim_uc.softmax(dim=-1)
-        attn_c = [sim.softmax(dim=-1) for sim in sim_c]
-
-        # get uc output
-        out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc)
-
-        # get c output
-        if len(v_c) == 1:
-            out_c_collect = []
-            for attn in attn_c:
-                for v in v_c:
-                    out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v))
-            out_c = sum(out_c_collect) / len(out_c_collect)
-        else:
-            out_c = sum(
-                [
-                    einsum("b i j, b j d -> b i d", attn, v)
-                    for attn, v in zip(attn_c, v_c)
-                ]
-            ) / len(v_c)
-
-        out = torch.cat([out_uc, out_c], dim=0)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def normal_qkv(self, q, context, mask):
-        h = self.heads
-        k = self.to_k(context)
-        v = self.to_v(context)
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-        k = self.to_k(torch.cat([uc_context, context_k], dim=0))
-        v = self.to_v(torch.cat([uc_context, context_v], dim=0))
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-        return out
-
-    def get_kv(self, context):
-        return self.to_k(context), self.to_v(context)
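
Note: structured.py relied on einsum, rearrange, repeat, default and exists without importing them (they came from einops and the surrounding latent-diffusion helpers), and its first forward definition was shadowed by the second one. The struct_qkv path expects the cross-attention context as a list of [unconditional embeddings, {"k": ..., "v": ...}], where "k"/"v" are either single tensors or per-concept lists. A small shape sketch with hypothetical CLIP-like dimensions:

    import torch

    batch, tokens, dim = 2, 77, 768
    uc = torch.randn(batch, tokens, dim)                            # unconditional embeddings
    concepts = [torch.randn(batch, tokens, dim) for _ in range(3)]  # one entry per composed prompt

    context = [uc, {"k": concepts, "v": concepts}]          # dispatches to multi_qkv
    # context = [uc, {"k": concepts[0], "v": concepts[0]}]  # dispatches to heterogeous_qkv
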
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
deleted file mode 100644
index 5798bcf..0000000
--- a/models/convnext/discriminator.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-from timm.models import ConvNeXt
-from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
-from torch.nn import functional as F
-
-
-class ConvNeXtDiscriminator:
-    def __init__(self, model: ConvNeXt, input_size: int) -> None:
-        self.net = model
-
-        self.input_size = input_size
-
-        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
-        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
-
-    def get_score(self, img):
-        pred = self.get_all(img)
-        return torch.softmax(pred, dim=-1)[:, 1]
-
-    def get_all(self, img):
-        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
-        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
-
-        img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std)
-
-        img = F.interpolate(
-            img,
-            size=(self.input_size, self.input_size),
-            mode="bicubic",
-            align_corners=True,
-        )
-        pred = self.net(img)
-        return pred
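
Note: the removed discriminator is a thin wrapper around a timm ConvNeXt classifier; get_all maps decoded images from [-1, 1] to ImageNet statistics and resizes them to input_size, and get_score returns the softmax probability of class index 1. A hedged construction sketch that mirrors the commented-out block this commit also drops from train_ti.py (model name and class count taken from there; input_size is an assumption):

    import torch
    from timm import create_model

    from models.convnext.discriminator import ConvNeXtDiscriminator  # removed by this commit

    convnext = create_model("convnext_tiny", pretrained=False, num_classes=3)
    convnext.requires_grad_(False)
    convnext.eval()

    disc = ConvNeXtDiscriminator(convnext, input_size=224)

    images = torch.rand(4, 3, 512, 512) * 2 - 1  # decoded VAE output lives in [-1, 1]
    scores = disc.get_score(images)              # shape (4,), probability of class index 1
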
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 16b8456..98703d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -28,7 +28,6 @@ from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer

 from models.clip.util import unify_input_ids, get_extended_embeddings
-from util.noise import perlin_noise

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

diff --git a/train_ti.py b/train_ti.py
index da0c03e..7d1ef19 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -698,16 +698,6 @@ def main():
     unet.enable_gradient_checkpointing()
     text_encoder.gradient_checkpointing_enable()

-    # convnext = create_model(
-    #     "convnext_tiny",
-    #     pretrained=False,
-    #     num_classes=3,
-    #     drop_path_rate=0.0,
-    # )
-    # convnext.to(accelerator.device, dtype=weight_dtype)
-    # convnext.requires_grad_(False)
-    # convnext.eval()
-
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
diff --git a/training/functional.py b/training/functional.py
index 3c7848f..a3d1f08 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -29,11 +29,8 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
-from util.slerp import slerp
-from util.noise import perlin_noise


 def const(result=None):
@@ -349,7 +346,6 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     input_pertubation: float,
-    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -449,13 +445,6 @@ def loss_step(

     loss = loss.mean([1, 2, 3])

-    if disc is not None:
-        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
-        rec_latent = rec_latent / vae.config.scaling_factor
-        rec_latent = rec_latent.to(dtype=vae.dtype)
-        rec = vae.decode(rec_latent, return_dict=False)[0]
-        loss = 1 - disc.get_score(rec)
-
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
@@ -741,7 +730,6 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     input_pertubation: float = 0.1,
-    disc: Optional[ConvNeXtDiscriminator] = None,
     schedule_sampler: Optional[ScheduleSampler] = None,
     min_snr_gamma: int = 5,
     avg_loss: AverageMeter = AverageMeter(),
@@ -803,7 +791,6 @@ def train(
         prior_loss_weight,
         seed,
         input_pertubation,
-        disc,
         min_snr_gamma,
     )

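
Note: after this commit, loss_step goes directly from the per-sample MSE to the optional min-SNR weighting; the training loop no longer decodes latents or calls a discriminator, and callers of loss_step/train stop passing disc. For reference, the dropped branch assembled into one sketch (get_original is this repo's helper for recovering the denoised latent from the model prediction):

    def discriminator_loss(noise_scheduler, model_pred, noisy_latents, timesteps, vae, disc):
        # reconstruct x0 in latent space, decode it, and score it with the ConvNeXt head
        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
        rec_latent = rec_latent / vae.config.scaling_factor
        rec_latent = rec_latent.to(dtype=vae.dtype)
        rec = vae.decode(rec_latent, return_dict=False)[0]
        return 1 - disc.get_score(rec)
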