author     Volpeon <git@volpeon.ink>    2023-05-16 12:59:08 +0200
committer  Volpeon <git@volpeon.ink>    2023-05-16 12:59:08 +0200
commit     1aace3e44dae0489130039714f67d980628c92ec
tree       59a972b64bb3a3253e310055fc24381db68e8608
parent     Patch xformers to cast dtypes
Avoid model recompilation due to varying prompt lengths
-rw-r--r--  data/csv.py                      | 31
-rw-r--r--  models/clip/util.py              | 23
-rw-r--r--  train_lora.py                    |  3
-rw-r--r--  training/attention_processor.py  | 47
-rw-r--r--  training/functional.py           |  4

5 files changed, 43 insertions, 65 deletions
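Note: the point of this commit is that torch.compile specializes the compiled graph on tensor shapes, so batches whose padded token length changes from step to step keep triggering recompilation. Padding every prompt to one constant length keeps the conditioning tensors at a fixed shape. A minimal sketch of that idea, independent of this repo (the checkpoint name and prompt strings are arbitrary stand-ins):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompts = [
    "a photo of a cat",
    "a highly detailed photo of a cat sitting on a windowsill at dusk",
]

# Dynamic padding: the sequence length depends on the longest prompt in *this*
# batch, so consecutive batches can produce differently shaped input_ids tensors.
dynamic = tokenizer(prompts, padding=True, return_tensors="pt")

# Constant padding: every batch is padded to the same length, so a compiled
# model keeps seeing one stable shape.
constant = tokenizer(
    prompts,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)

print(dynamic.input_ids.shape, constant.input_ids.shape)
```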
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
+def collate_fn(
+    dtype: torch.dtype,
+    tokenizer: CLIPTokenizer,
+    max_token_id_length: Optional[int],
+    with_guidance: bool,
+    with_prior_preservation: bool,
+    examples
+):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
-    prompts = unify_input_ids(tokenizer, prompt_ids)
-    nprompts = unify_input_ids(tokenizer, nprompt_ids)
-    inputs = unify_input_ids(tokenizer, input_ids)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
+    inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
         batch_size: int,
         data_file: str,
         tokenizer: CLIPTokenizer,
+        constant_prompt_length: bool = False,
         class_subdir: str = "cls",
         with_guidance: bool = False,
         num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
         self.num_class_images = num_class_images
         self.with_guidance = with_guidance
 
+        self.constant_prompt_length = constant_prompt_length
+        self.max_token_id_length = None
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
+
+        if self.constant_prompt_length:
+            all_input_ids = unify_input_ids(
+                self.tokenizer,
+                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+            ).input_ids
+            self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
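For reference, the setup step above tokenizes every item's full prompt once without padding, pads the whole list to a multiple of the tokenizer's model_max_length, and records the resulting width as the length every later batch is padded to. A standalone sketch of that computation (the checkpoint and the prompt list are made-up stand-ins for the repo's dataset items and item.full_prompt()):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Stand-ins for item.full_prompt(); the real values come from the dataset items.
all_prompts = ["a cat", "a watercolor painting of a cat wearing a tiny wizard hat"]

unpadded_ids = [tokenizer(p, padding="do_not_pad").input_ids for p in all_prompts]
padded = tokenizer.pad(
    {"input_ids": unpadded_ids},
    padding=True,
    pad_to_multiple_of=tokenizer.model_max_length,
    return_tensors="pt",
)

# One fixed length (a multiple of model_max_length) reused for every batch.
max_token_id_length = padded.input_ids.shape[1]
print(max_token_id_length)
```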
diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
-    return tokenizer.pad(
-        {"input_ids": input_ids},
-        padding=True,
-        pad_to_multiple_of=tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+    if max_length is None:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=tokenizer.model_max_length,
+            return_tensors="pt"
+        )
+    else:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding="max_length",
+            max_length=max_length,
+            return_tensors="pt"
+        )
 
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
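A hypothetical call site for the updated helper, showing both branches (the import is the module patched above; the checkpoint and prompts are illustrative):

```python
from transformers import CLIPTokenizer

from models.clip.util import unify_input_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = [tokenizer(p, padding="do_not_pad").input_ids for p in ["a cat", "a cat on a mat"]]

# max_length=None: padded to the next multiple of model_max_length for this batch,
# so the width can still differ between batches with very long prompts.
dynamic = unify_input_ids(tokenizer, ids)

# Explicit max_length: every batch comes back with exactly this width.
fixed = unify_input_ids(tokenizer, ids, max_length=tokenizer.model_max_length)

print(dynamic.input_ids.shape, fixed.input_ids.shape)
```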
diff --git a/train_lora.py b/train_lora.py
index a58bef7..12d7e72 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
-torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -992,6 +992,7 @@ def main():
         VlpnDataModule,
         data_file=args.train_data_file,
         tokenizer=tokenizer,
+        constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Optional, Union
-
-import xformers
-import xformers.ops
-
-from diffusers.models.attention_processor import Attention
-
-
-class XFormersAttnProcessor:
-    def __init__(self, attention_op: Optional[Callable] = None):
-        self.attention_op = attention_op
-
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        query = query.to(key.dtype)
-        value = value.to(key.dtype)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet, mode="reduce-overhead")
+        # unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
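The UNet is now compiled with mode="reduce-overhead" instead of the hidet backend; that mode uses CUDA graphs, which capture fixed shapes, so it benefits directly from the constant prompt length introduced above. A minimal, repo-independent sketch of such a call, with a toy module standing in for the diffusers UNet (a CUDA device is assumed):

```python
import torch

# Toy stand-in for the UNet; in the repo the compiled module is the diffusers UNet.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64)
).cuda()

compiled = torch.compile(model, mode="reduce-overhead")

# As long as the input shape stays fixed, repeated calls reuse the compiled graph.
x = torch.randn(8, 64, device="cuda")
for _ in range(3):
    out = compiled(x)
```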