Diffstat (limited to 'data')

 -rw-r--r--  data/csv.py | 31
 1 file changed, 24 insertions(+), 7 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
+def collate_fn(
+    dtype: torch.dtype,
+    tokenizer: CLIPTokenizer,
+    max_token_id_length: Optional[int],
+    with_guidance: bool,
+    with_prior_preservation: bool,
+    examples
+):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
-    prompts = unify_input_ids(tokenizer, prompt_ids)
-    nprompts = unify_input_ids(tokenizer, nprompt_ids)
-    inputs = unify_input_ids(tokenizer, input_ids)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
+    inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
         batch_size: int,
         data_file: str,
         tokenizer: CLIPTokenizer,
+        constant_prompt_length: bool = False,
         class_subdir: str = "cls",
         with_guidance: bool = False,
         num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
         self.num_class_images = num_class_images
         self.with_guidance = with_guidance
 
+        self.constant_prompt_length = constant_prompt_length
+        self.max_token_id_length = None
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
+
+        if self.constant_prompt_length:
+            all_input_ids = unify_input_ids(
+                self.tokenizer,
+                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+            ).input_ids
+            self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
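
With constant_prompt_length enabled, this hunk tokenizes every item's full prompt unpadded, pads the whole set to its longest member via unify_input_ids, and stores that width in max_token_id_length; collate_fn then pads every batch to this dataset-wide width rather than to the longest prompt within the batch, so train and validation batches all share one input_ids shape. Presumably that is the point of the flag: constant shapes are what shape-specialized execution (for example torch.compile without dynamic shapes) needs to avoid recompiling per batch. A hypothetical usage sketch; every constructor argument except the new flag, plus the setup method name and file path, are assumptions for illustration:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = VlpnDataModule(
    batch_size=4,                 # illustrative value
    data_file="items.json",       # hypothetical path
    tokenizer=tokenizer,
    constant_prompt_length=True,  # opt in to fixed-width prompts
)

# The hunk above belongs to the method that prepares the train/valid
# split (called setup here as an assumption). After it runs,
# datamodule.max_token_id_length holds the dataset-wide token width that
# collate_fn receives through functools.partial.
datamodule.setup()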