summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py31
1 files changed, 24 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
100 return buckets, bucket_items, bucket_assignments 100 return buckets, bucket_items, bucket_assignments
101 101
102 102
103def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): 103def collate_fn(
104 dtype: torch.dtype,
105 tokenizer: CLIPTokenizer,
106 max_token_id_length: Optional[int],
107 with_guidance: bool,
108 with_prior_preservation: bool,
109 examples
110):
104 prompt_ids = [example["prompt_ids"] for example in examples] 111 prompt_ids = [example["prompt_ids"] for example in examples]
105 nprompt_ids = [example["nprompt_ids"] for example in examples] 112 nprompt_ids = [example["nprompt_ids"] for example in examples]
106 113
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
115 pixel_values = torch.stack(pixel_values) 122 pixel_values = torch.stack(pixel_values)
116 pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format) 123 pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
117 124
118 prompts = unify_input_ids(tokenizer, prompt_ids) 125 prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
119 nprompts = unify_input_ids(tokenizer, nprompt_ids) 126 nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
120 inputs = unify_input_ids(tokenizer, input_ids) 127 inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
121 negative_inputs = unify_input_ids(tokenizer, negative_input_ids) 128 negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
122 129
123 batch = { 130 batch = {
124 "prompt_ids": prompts.input_ids, 131 "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
176 batch_size: int, 183 batch_size: int,
177 data_file: str, 184 data_file: str,
178 tokenizer: CLIPTokenizer, 185 tokenizer: CLIPTokenizer,
186 constant_prompt_length: bool = False,
179 class_subdir: str = "cls", 187 class_subdir: str = "cls",
180 with_guidance: bool = False, 188 with_guidance: bool = False,
181 num_class_images: int = 1, 189 num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
212 self.num_class_images = num_class_images 220 self.num_class_images = num_class_images
213 self.with_guidance = with_guidance 221 self.with_guidance = with_guidance
214 222
223 self.constant_prompt_length = constant_prompt_length
224 self.max_token_id_length = None
225
215 self.tokenizer = tokenizer 226 self.tokenizer = tokenizer
216 self.size = size 227 self.size = size
217 self.num_buckets = num_buckets 228 self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
301 items = self.prepare_items(template, expansions, items) 312 items = self.prepare_items(template, expansions, items)
302 items = self.filter_items(items) 313 items = self.filter_items(items)
303 self.npgenerator.shuffle(items) 314 self.npgenerator.shuffle(items)
315
316 if self.constant_prompt_length:
317 all_input_ids = unify_input_ids(
318 self.tokenizer,
319 [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
320 ).input_ids
321 self.max_token_id_length = all_input_ids.shape[1]
304 322
305 num_images = len(items) 323 num_images = len(items)
306
307 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 324 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
308 train_set_size = max(num_images - valid_set_size, 1) 325 train_set_size = max(num_images - valid_set_size, 1)
309 valid_set_size = num_images - train_set_size 326 valid_set_size = num_images - train_set_size
310 327
311 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) 328 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
312 329
313 if valid_set_size == 0: 330 if valid_set_size == 0:
314 data_train, data_val = items, items 331 data_train, data_val = items, items