summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py21
1 files changed, 8 insertions, 13 deletions
diff --git a/data/csv.py b/data/csv.py
index a3fef30..df3ee77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,20 +100,16 @@ def generate_buckets(
100 return buckets, bucket_items, bucket_assignments 100 return buckets, bucket_items, bucket_assignments
101 101
102 102
103def collate_fn( 103def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
104 num_class_images: int, 104 with_prior = all("class_prompt_ids" in example for example in examples)
105 weight_dtype: torch.dtype, 105
106 tokenizer: CLIPTokenizer,
107 examples
108):
109 prompt_ids = [example["prompt_ids"] for example in examples] 106 prompt_ids = [example["prompt_ids"] for example in examples]
110 nprompt_ids = [example["nprompt_ids"] for example in examples] 107 nprompt_ids = [example["nprompt_ids"] for example in examples]
111 108
112 input_ids = [example["instance_prompt_ids"] for example in examples] 109 input_ids = [example["instance_prompt_ids"] for example in examples]
113 pixel_values = [example["instance_images"] for example in examples] 110 pixel_values = [example["instance_images"] for example in examples]
114 111
115 # concat class and instance examples for prior preservation 112 if with_prior:
116 if num_class_images != 0 and "class_prompt_ids" in examples[0]:
117 input_ids += [example["class_prompt_ids"] for example in examples] 113 input_ids += [example["class_prompt_ids"] for example in examples]
118 pixel_values += [example["class_images"] for example in examples] 114 pixel_values += [example["class_images"] for example in examples]
119 115
@@ -125,6 +121,7 @@ def collate_fn(
125 inputs = unify_input_ids(tokenizer, input_ids) 121 inputs = unify_input_ids(tokenizer, input_ids)
126 122
127 batch = { 123 batch = {
124 "with_prior": torch.tensor(with_prior),
128 "prompt_ids": prompts.input_ids, 125 "prompt_ids": prompts.input_ids,
129 "nprompt_ids": nprompts.input_ids, 126 "nprompt_ids": nprompts.input_ids,
130 "input_ids": inputs.input_ids, 127 "input_ids": inputs.input_ids,
@@ -166,7 +163,6 @@ class VlpnDataModule():
166 seed: Optional[int] = None, 163 seed: Optional[int] = None,
167 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 164 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
168 dtype: torch.dtype = torch.float32, 165 dtype: torch.dtype = torch.float32,
169 num_workers: int = 0
170 ): 166 ):
171 super().__init__() 167 super().__init__()
172 168
@@ -194,7 +190,6 @@ class VlpnDataModule():
194 self.valid_set_repeat = valid_set_repeat 190 self.valid_set_repeat = valid_set_repeat
195 self.seed = seed 191 self.seed = seed
196 self.filter = filter 192 self.filter = filter
197 self.num_workers = num_workers
198 self.batch_size = batch_size 193 self.batch_size = batch_size
199 self.dtype = dtype 194 self.dtype = dtype
200 195
@@ -290,16 +285,16 @@ class VlpnDataModule():
290 size=self.size, interpolation=self.interpolation, 285 size=self.size, interpolation=self.interpolation,
291 ) 286 )
292 287
293 collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) 288 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
294 289
295 self.train_dataloader = DataLoader( 290 self.train_dataloader = DataLoader(
296 train_dataset, 291 train_dataset,
297 batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers 292 batch_size=None, pin_memory=True, collate_fn=collate_fn_
298 ) 293 )
299 294
300 self.val_dataloader = DataLoader( 295 self.val_dataloader = DataLoader(
301 val_dataset, 296 val_dataset,
302 batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers 297 batch_size=None, pin_memory=True, collate_fn=collate_fn_
303 ) 298 )
304 299
305 300