Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  21
1 file changed, 8 insertions, 13 deletions
diff --git a/data/csv.py b/data/csv.py
index a3fef30..df3ee77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,20 +100,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(
-    num_class_images: int,
-    weight_dtype: torch.dtype,
-    tokenizer: CLIPTokenizer,
-    examples
-):
+def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
+    with_prior = all("class_prompt_ids" in example for example in examples)
+
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    # concat class and instance examples for prior preservation
-    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+    if with_prior:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -125,6 +121,7 @@ def collate_fn(
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
+        "with_prior": torch.tensor(with_prior),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
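The hunks above change collate_fn so that prior preservation is inferred from the batch itself: with_prior is true only when every example carries class_prompt_ids, and the flag travels with the batch as a tensor instead of being derived from a separate num_class_images argument. A minimal sketch of how a training step could consume that flag follows; loss_step, model_pred, target, and prior_loss_weight are hypothetical names used only for illustration and are not part of this commit.

import torch
import torch.nn.functional as F

def loss_step(model_pred: torch.Tensor, target: torch.Tensor, batch: dict,
              prior_loss_weight: float = 1.0) -> torch.Tensor:
    # Hypothetical consumer of batch["with_prior"]; assumes class examples were
    # appended after the instance examples, as collate_fn does above.
    if batch["with_prior"].item():
        # First half of the batch: instance data; second half: class data.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
        return loss + prior_loss_weight * prior_loss

    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")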
@@ -166,7 +163,6 @@ class VlpnDataModule():
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
-        num_workers: int = 0
     ):
         super().__init__()
 
@@ -194,7 +190,6 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.num_workers = num_workers
         self.batch_size = batch_size
         self.dtype = dtype
 
@@ -290,16 +285,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
        )
 
        self.val_dataloader = DataLoader(
            val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
        )
 
 
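The remaining hunks remove the num_workers option from VlpnDataModule (the loaders now run with the DataLoader default of zero workers) and rebind the slimmer collate_fn with functools.partial. The loaders keep batch_size=None, which disables automatic batching, so each dataset item is presumably already a list of examples that collate_fn turns into one batch. A self-contained sketch of that wiring with toy stand-ins (toy_collate_fn and an in-memory list of pre-built batches, neither of which is part of this code) follows.

from functools import partial

import torch
from torch.utils.data import DataLoader

def toy_collate_fn(weight_dtype: torch.dtype, examples: list) -> dict:
    # Stand-in for collate_fn(weight_dtype, tokenizer, examples): with automatic
    # batching disabled, it receives one pre-built batch (a list of example dicts).
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    return {"pixel_values": pixel_values.to(dtype=weight_dtype)}

# Two pre-built batches of four examples each, mimicking a bucketed dataset.
batches = [[{"pixel_values": torch.randn(3, 64, 64)} for _ in range(4)] for _ in range(2)]

loader = DataLoader(
    batches,
    batch_size=None, pin_memory=False,
    collate_fn=partial(toy_collate_fn, torch.float32),
)

for batch in loader:
    print(batch["pixel_values"].shape)  # torch.Size([4, 3, 64, 64])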
