From 1aace3e44dae0489130039714f67d980628c92ec Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 16 May 2023 12:59:08 +0200 Subject: Avoid model recompilation due to varying prompt lengths --- data/csv.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 81e8b6b..14380e8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -100,7 +100,14 @@ def generate_buckets( return buckets, bucket_items, bucket_assignments -def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): +def collate_fn( + dtype: torch.dtype, + tokenizer: CLIPTokenizer, + max_token_id_length: Optional[int], + with_guidance: bool, + with_prior_preservation: bool, + examples +): prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] @@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format) - prompts = unify_input_ids(tokenizer, prompt_ids) - nprompts = unify_input_ids(tokenizer, nprompt_ids) - inputs = unify_input_ids(tokenizer, input_ids) - negative_inputs = unify_input_ids(tokenizer, negative_input_ids) + prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) + nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) + inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) + negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) batch = { "prompt_ids": prompts.input_ids, @@ -176,6 +183,7 @@ class VlpnDataModule(): batch_size: int, data_file: str, tokenizer: CLIPTokenizer, + constant_prompt_length: bool = False, class_subdir: str = "cls", with_guidance: bool = False, num_class_images: int = 1, @@ -212,6 +220,9 @@ class VlpnDataModule(): self.num_class_images = num_class_images self.with_guidance = with_guidance + self.constant_prompt_length = constant_prompt_length + self.max_token_id_length = None + self.tokenizer = tokenizer self.size = size self.num_buckets = num_buckets @@ -301,14 +312,20 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) self.npgenerator.shuffle(items) + + if self.constant_prompt_length: + all_input_ids = unify_input_ids( + self.tokenizer, + [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] + ).input_ids + self.max_token_id_length = all_input_ids.shape[1] num_images = len(items) - valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 train_set_size = max(num_images - valid_set_size, 1) valid_set_size = num_images - train_set_size - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) if valid_set_size == 0: data_train, data_val = items, items -- cgit v1.2.3-70-g09d2