Diffstat (limited to 'data')
-rw-r--r--   data/csv.py   31
1 file changed, 24 insertions(+), 7 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
+def collate_fn(
+    dtype: torch.dtype,
+    tokenizer: CLIPTokenizer,
+    max_token_id_length: Optional[int],
+    with_guidance: bool,
+    with_prior_preservation: bool,
+    examples
+):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
-    prompts = unify_input_ids(tokenizer, prompt_ids)
-    nprompts = unify_input_ids(tokenizer, nprompt_ids)
-    inputs = unify_input_ids(tokenizer, input_ids)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
+    inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
         batch_size: int,
         data_file: str,
         tokenizer: CLIPTokenizer,
+        constant_prompt_length: bool = False,
         class_subdir: str = "cls",
         with_guidance: bool = False,
         num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
         self.num_class_images = num_class_images
         self.with_guidance = with_guidance
 
+        self.constant_prompt_length = constant_prompt_length
+        self.max_token_id_length = None
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
+
+        if self.constant_prompt_length:
+            all_input_ids = unify_input_ids(
+                self.tokenizer,
+                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+            ).input_ids
+            self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
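
Note on unify_input_ids: the diff threads the new max_token_id_length argument through every unify_input_ids call, but the helper itself lies outside these hunks. A minimal sketch of a compatible implementation, assuming it wraps Hugging Face's tokenizer.pad (the real helper in data/csv.py may differ in details such as pad_to_multiple_of):

from typing import Optional

from transformers import CLIPTokenizer


def unify_input_ids(
    tokenizer: CLIPTokenizer,
    input_ids: list[list[int]],
    max_token_id_length: Optional[int] = None,
):
    if max_token_id_length is None:
        # Previous behavior: pad each batch dynamically to its own longest prompt.
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        )
    # New behavior: pad every batch to one precomputed global width,
    # so tensor shapes stay constant across batches.
    return tokenizer.pad(
        {"input_ids": input_ids},
        padding="max_length",
        max_length=max_token_id_length,
        return_tensors="pt",
    )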
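
Call-site view: constant_prompt_length is opt-in and defaults to False, so existing users of VlpnDataModule are unaffected. A hypothetical invocation (values are placeholders; the remaining constructor arguments keep the defaults shown in the diff):

datamodule = VlpnDataModule(
    batch_size=4,
    data_file="data/train.csv",   # placeholder path
    tokenizer=tokenizer,          # an existing CLIPTokenizer instance
    constant_prompt_length=True,  # precompute one global prompt width
)

With the default False, max_token_id_length stays None and collate_fn falls back to per-batch padding, matching the behavior before this change.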
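
How the global width is found: at setup time, every item's full prompt is tokenized once with padding="do_not_pad" so each sequence keeps its true length; unifying that full list in a single call yields the dataset-wide padded width, which is stored as self.max_token_id_length and later bound into collate_fn via partial. A self-contained sketch of the same two-pass technique using only the tokenizer API (the prompt strings are made up):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompts = [
    "a photo of a cat",
    "a watercolor painting of a castle on a hill at sunset",
]

# Pass 1 (once, at setup): tokenize without padding to keep true lengths,
# then pad the whole list jointly to discover the global width.
encoded = [tokenizer(p, padding="do_not_pad").input_ids for p in prompts]
global_width = tokenizer.pad(
    {"input_ids": encoded}, padding=True, return_tensors="pt"
).input_ids.shape[1]

# Pass 2 (per batch, at training time): pad each batch to that same width,
# so every batch tensor has an identical shape.
batch = tokenizer.pad(
    {"input_ids": encoded},
    padding="max_length",
    max_length=global_width,
    return_tensors="pt",
)
print(batch.input_ids.shape)  # torch.Size([2, global_width])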
