 data/csv.py                     | 31
 models/clip/util.py             | 23
 train_lora.py                   |  3
 training/attention_processor.py | 47
 training/functional.py          |  4
 5 files changed, 43 insertions(+), 65 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
+def collate_fn(
+    dtype: torch.dtype,
+    tokenizer: CLIPTokenizer,
+    max_token_id_length: Optional[int],
+    with_guidance: bool,
+    with_prior_preservation: bool,
+    examples
+):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
-    prompts = unify_input_ids(tokenizer, prompt_ids)
-    nprompts = unify_input_ids(tokenizer, nprompt_ids)
-    inputs = unify_input_ids(tokenizer, input_ids)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
+    inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
         batch_size: int,
         data_file: str,
         tokenizer: CLIPTokenizer,
+        constant_prompt_length: bool = False,
         class_subdir: str = "cls",
         with_guidance: bool = False,
         num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
         self.num_class_images = num_class_images
         self.with_guidance = with_guidance
 
+        self.constant_prompt_length = constant_prompt_length
+        self.max_token_id_length = None
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
+
+        if self.constant_prompt_length:
+            all_input_ids = unify_input_ids(
+                self.tokenizer,
+                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+            ).input_ids
+            self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
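The new constant_prompt_length path boils down to: tokenize every prompt once during setup, pad them jointly, and remember the resulting width so every later batch can be padded to the same length. A minimal sketch of that idea, outside the diff (the checkpoint id and prompts below are placeholders):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
prompts = ["a photo of a cat", "a watercolor painting of a castle on a hill"]

# Tokenize without padding, then pad the whole corpus jointly.
all_ids = [tokenizer(p, padding="do_not_pad").input_ids for p in prompts]
corpus = tokenizer.pad(
    {"input_ids": all_ids},
    padding=True,
    pad_to_multiple_of=tokenizer.model_max_length,  # same rounding unify_input_ids applies
    return_tensors="pt",
)
max_token_id_length = corpus.input_ids.shape[1]  # 77 for prompts this short

# Every later batch is padded to exactly this width, giving constant tensor shapes.
batch = tokenizer.pad(
    {"input_ids": all_ids[:1]},
    padding="max_length",
    max_length=max_token_id_length,
    return_tensors="pt",
)
assert batch.input_ids.shape[1] == max_token_id_length

Because every batch then shares the same input_ids width, modules compiled with torch.compile see one constant sequence length for the whole run.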
diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
-    return tokenizer.pad(
-        {"input_ids": input_ids},
-        padding=True,
-        pad_to_multiple_of=tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+    if max_length is None:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=tokenizer.model_max_length,
+            return_tensors="pt"
+        )
+    else:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding="max_length",
+            max_length=max_length,
+            return_tensors="pt"
+        )
 
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
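After this change unify_input_ids has two padding modes. A short usage sketch, assuming the repository root is on PYTHONPATH and using a stock CLIP tokenizer as a stand-in:

from transformers import CLIPTokenizer
from models.clip.util import unify_input_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # placeholder checkpoint
ids = [tokenizer(p, padding="do_not_pad").input_ids
       for p in ("a cat", "a very detailed oil painting of a cat")]

dynamic = unify_input_ids(tokenizer, ids)               # pads to a multiple of model_max_length (77 for CLIP)
fixed = unify_input_ids(tokenizer, ids, max_length=77)  # pads every batch to exactly 77 tokens
print(dynamic.input_ids.shape, fixed.input_ids.shape)   # both torch.Size([2, 77]) for prompts this short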
diff --git a/train_lora.py b/train_lora.py
index a58bef7..12d7e72 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
-torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -992,6 +992,7 @@ def main():
         VlpnDataModule,
         data_file=args.train_data_file,
         tokenizer=tokenizer,
+        constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
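Tying constant_prompt_length to args.compile_unet follows from torch.compile specializing on tensor shapes. An illustrative sketch, not from the diff, of the recompilation behaviour that constant prompt lengths avoid:

import torch

@torch.compile  # specializes on the shapes it first sees
def encode(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x)

encode(torch.randn(1, 77, 768))   # compiled for sequence length 77
encode(torch.randn(1, 154, 768))  # a new length may trigger recompilation (or a switch to dynamic shapes)

With fixed-length prompts, the compiled UNet's text-conditioning input keeps a single shape for the whole training run.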
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Optional, Union
-
-import xformers
-import xformers.ops
-
-from diffusers.models.attention_processor import Attention
-
-
-class XFormersAttnProcessor:
-    def __init__(self, attention_op: Optional[Callable] = None):
-        self.attention_op = attention_op
-
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        query = query.to(key.dtype)
-        value = value.to(key.dtype)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
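The deleted module duplicated an attention processor that diffusers ships itself; presumably the built-in path is used instead, though this diff does not show that switch. A hedged sketch of the built-in route (the model id is a placeholder):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import XFormersAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.set_attn_processor(XFormersAttnProcessor())  # built-in equivalent of the removed class
# or, more simply:
unet.enable_xformers_memory_efficient_attention()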
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet, mode="reduce-overhead")
+        # unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
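For context, a sketch assuming a CUDA GPU and the hidet package installed for the first variant: the swap replaces the hidet dynamo backend with the default inductor backend in "reduce-overhead" mode, which uses CUDA graphs to cut per-step launch overhead.

import torch

model = torch.nn.Linear(768, 768).cuda()
compiled_hidet = torch.compile(model, backend="hidet")            # previous choice in this file
compiled_inductor = torch.compile(model, mode="reduce-overhead")  # new choice: inductor + CUDA graphs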
