diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 21 |
1 files changed, 8 insertions, 13 deletions
diff --git a/data/csv.py b/data/csv.py index a3fef30..df3ee77 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -100,20 +100,16 @@ def generate_buckets( | |||
100 | return buckets, bucket_items, bucket_assignments | 100 | return buckets, bucket_items, bucket_assignments |
101 | 101 | ||
102 | 102 | ||
103 | def collate_fn( | 103 | def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): |
104 | num_class_images: int, | 104 | with_prior = all("class_prompt_ids" in example for example in examples) |
105 | weight_dtype: torch.dtype, | 105 | |
106 | tokenizer: CLIPTokenizer, | ||
107 | examples | ||
108 | ): | ||
109 | prompt_ids = [example["prompt_ids"] for example in examples] | 106 | prompt_ids = [example["prompt_ids"] for example in examples] |
110 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 107 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
111 | 108 | ||
112 | input_ids = [example["instance_prompt_ids"] for example in examples] | 109 | input_ids = [example["instance_prompt_ids"] for example in examples] |
113 | pixel_values = [example["instance_images"] for example in examples] | 110 | pixel_values = [example["instance_images"] for example in examples] |
114 | 111 | ||
115 | # concat class and instance examples for prior preservation | 112 | if with_prior: |
116 | if num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
117 | input_ids += [example["class_prompt_ids"] for example in examples] | 113 | input_ids += [example["class_prompt_ids"] for example in examples] |
118 | pixel_values += [example["class_images"] for example in examples] | 114 | pixel_values += [example["class_images"] for example in examples] |
119 | 115 | ||
@@ -125,6 +121,7 @@ def collate_fn( | |||
125 | inputs = unify_input_ids(tokenizer, input_ids) | 121 | inputs = unify_input_ids(tokenizer, input_ids) |
126 | 122 | ||
127 | batch = { | 123 | batch = { |
124 | "with_prior": torch.tensor(with_prior), | ||
128 | "prompt_ids": prompts.input_ids, | 125 | "prompt_ids": prompts.input_ids, |
129 | "nprompt_ids": nprompts.input_ids, | 126 | "nprompt_ids": nprompts.input_ids, |
130 | "input_ids": inputs.input_ids, | 127 | "input_ids": inputs.input_ids, |
@@ -166,7 +163,6 @@ class VlpnDataModule(): | |||
166 | seed: Optional[int] = None, | 163 | seed: Optional[int] = None, |
167 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 164 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
168 | dtype: torch.dtype = torch.float32, | 165 | dtype: torch.dtype = torch.float32, |
169 | num_workers: int = 0 | ||
170 | ): | 166 | ): |
171 | super().__init__() | 167 | super().__init__() |
172 | 168 | ||
@@ -194,7 +190,6 @@ class VlpnDataModule(): | |||
194 | self.valid_set_repeat = valid_set_repeat | 190 | self.valid_set_repeat = valid_set_repeat |
195 | self.seed = seed | 191 | self.seed = seed |
196 | self.filter = filter | 192 | self.filter = filter |
197 | self.num_workers = num_workers | ||
198 | self.batch_size = batch_size | 193 | self.batch_size = batch_size |
199 | self.dtype = dtype | 194 | self.dtype = dtype |
200 | 195 | ||
@@ -290,16 +285,16 @@ class VlpnDataModule(): | |||
290 | size=self.size, interpolation=self.interpolation, | 285 | size=self.size, interpolation=self.interpolation, |
291 | ) | 286 | ) |
292 | 287 | ||
293 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) | 288 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer) |
294 | 289 | ||
295 | self.train_dataloader = DataLoader( | 290 | self.train_dataloader = DataLoader( |
296 | train_dataset, | 291 | train_dataset, |
297 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers | 292 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
298 | ) | 293 | ) |
299 | 294 | ||
300 | self.val_dataloader = DataLoader( | 295 | self.val_dataloader = DataLoader( |
301 | val_dataset, | 296 | val_dataset, |
302 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers | 297 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
303 | ) | 298 | ) |
304 | 299 | ||
305 | 300 | ||