From 1aace3e44dae0489130039714f67d980628c92ec Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 12:59:08 +0200
Subject: Avoid model recompilation due to varying prompt lengths

---
 data/csv.py                     | 31 +++++++++++++++++++++------
 models/clip/util.py             | 23 +++++++++++++-------
 train_lora.py                   |  3 ++-
 training/attention_processor.py | 47 -----------------------------------------
 training/functional.py          |  4 ++--
 5 files changed, 43 insertions(+), 65 deletions(-)
 delete mode 100644 training/attention_processor.py

diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
+def collate_fn(
+    dtype: torch.dtype,
+    tokenizer: CLIPTokenizer,
+    max_token_id_length: Optional[int],
+    with_guidance: bool,
+    with_prior_preservation: bool,
+    examples
+):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
-    prompts = unify_input_ids(tokenizer, prompt_ids)
-    nprompts = unify_input_ids(tokenizer, nprompt_ids)
-    inputs = unify_input_ids(tokenizer, input_ids)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
+    inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
         batch_size: int,
         data_file: str,
         tokenizer: CLIPTokenizer,
+        constant_prompt_length: bool = False,
         class_subdir: str = "cls",
         with_guidance: bool = False,
         num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
         self.num_class_images = num_class_images
         self.with_guidance = with_guidance
 
+        self.constant_prompt_length = constant_prompt_length
+        self.max_token_id_length = None
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
+
+        if self.constant_prompt_length:
+            all_input_ids = unify_input_ids(
+                self.tokenizer,
+                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+            ).input_ids
+            self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
-    return tokenizer.pad(
-        {"input_ids": input_ids},
-        padding=True,
-        pad_to_multiple_of=tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+    if max_length is None:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=tokenizer.model_max_length,
+            return_tensors="pt"
+        )
+    else:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding="max_length",
+            max_length=max_length,
+            return_tensors="pt"
+        )
 
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
diff --git a/train_lora.py b/train_lora.py
index a58bef7..12d7e72 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
-torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -992,6 +992,7 @@ def main():
         VlpnDataModule,
         data_file=args.train_data_file,
         tokenizer=tokenizer,
+        constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Optional, Union
-
-import xformers
-import xformers.ops
-
-from diffusers.models.attention_processor import Attention
-
-
-class XFormersAttnProcessor:
-    def __init__(self, attention_op: Optional[Callable] = None):
-        self.attention_op = attention_op
-
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        query = query.to(key.dtype)
-        value = value.to(key.dtype)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet, mode="reduce-overhead")
+        # unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
-- 
cgit v1.2.3-70-g09d2
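
Note on the approach (not part of the patch): torch.compile retraces a compiled model when its input tensor shapes change, and with dynamic padding the tokenized prompt length follows the longest prompt in each batch, so batches with different prompt lengths keep triggering recompilation. The sketch below is a stand-alone illustration of the padding strategy the diff switches to: padding every batch to one length chosen up front instead of to the longest prompt in the batch. It uses only the Hugging Face CLIPTokenizer; the prompt strings and the value 77 are made up for illustration, while the patch computes max_token_id_length once from all prompts in the dataset inside VlpnDataModule.

    # Stand-alone sketch: compare dynamic padding (shape follows the longest
    # prompt in the batch) with constant padding (shape fixed up front), which
    # is what keeps a compiled model from retracing.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    prompts = ["a photo of a cat", "a highly detailed oil painting of a castle at sunset"]
    ids = [tokenizer(p, padding="do_not_pad").input_ids for p in prompts]

    # Dynamic padding: a different batch composition gives a different sequence length.
    dynamic = tokenizer.pad({"input_ids": ids}, padding=True, return_tensors="pt")

    # Constant padding: every batch comes out with the same length, here a made-up 77
    # (the patch derives max_token_id_length from the dataset instead).
    constant = tokenizer.pad(
        {"input_ids": ids},
        padding="max_length",
        max_length=77,
        return_tensors="pt",
    )

    print(dynamic.input_ids.shape)   # depends on the longest prompt in this batch
    print(constant.input_ids.shape)  # always (batch_size, 77)

With constant shapes, the text encoder and the UNet conditioning see the same sequence length every step, so wiring constant_prompt_length to args.compile_unet (as the train_lora.py hunk does) only pays the padding overhead when compilation is actually in use.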