diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 39 |
1 files changed, 21 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py index dec66d7..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -174,7 +174,8 @@ class VlpnDataModule(): | |||
| 174 | interpolation: str = "bicubic", | 174 | interpolation: str = "bicubic", |
| 175 | template_key: str = "template", | 175 | template_key: str = "template", |
| 176 | valid_set_size: Optional[int] = None, | 176 | valid_set_size: Optional[int] = None, |
| 177 | valid_set_repeat: int = 1, | 177 | train_set_pad: Optional[int] = None, |
| 178 | valid_set_pad: Optional[int] = None, | ||
| 178 | seed: Optional[int] = None, | 179 | seed: Optional[int] = None, |
| 179 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 180 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
| 180 | dtype: torch.dtype = torch.float32, | 181 | dtype: torch.dtype = torch.float32, |
| @@ -202,7 +203,8 @@ class VlpnDataModule(): | |||
| 202 | self.template_key = template_key | 203 | self.template_key = template_key |
| 203 | self.interpolation = interpolation | 204 | self.interpolation = interpolation |
| 204 | self.valid_set_size = valid_set_size | 205 | self.valid_set_size = valid_set_size |
| 205 | self.valid_set_repeat = valid_set_repeat | 206 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size |
| 207 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size | ||
| 206 | self.seed = seed | 208 | self.seed = seed |
| 207 | self.filter = filter | 209 | self.filter = filter |
| 208 | self.batch_size = batch_size | 210 | self.batch_size = batch_size |
| @@ -267,9 +269,6 @@ class VlpnDataModule(): | |||
| 267 | items = self.prepare_items(template, expansions, items) | 269 | items = self.prepare_items(template, expansions, items) |
| 268 | items = self.filter_items(items) | 270 | items = self.filter_items(items) |
| 269 | 271 | ||
| 270 | if (len(items) < self.batch_size): | ||
| 271 | items = (items * self.batch_size)[:self.batch_size] | ||
| 272 | |||
| 273 | num_images = len(items) | 272 | num_images = len(items) |
| 274 | 273 | ||
| 275 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 274 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
| @@ -283,14 +282,17 @@ class VlpnDataModule(): | |||
| 283 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 284 | 283 | ||
| 285 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
| 286 | data_train, data_val = items, [] | 285 | data_train, data_val = items, items[:1] |
| 287 | else: | 286 | else: |
| 288 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
| 289 | 288 | ||
| 290 | self.data_train = self.pad_items(data_train, self.num_class_images) | 289 | data_train = self.pad_items(data_train, self.num_class_images) |
| 290 | |||
| 291 | if len(data_train) < self.train_set_pad: | ||
| 292 | data_train *= math.ceil(self.train_set_pad / len(data_train)) | ||
| 291 | 293 | ||
| 292 | train_dataset = VlpnDataset( | 294 | self.train_dataset = VlpnDataset( |
| 293 | self.data_train, self.tokenizer, | 295 | data_train, self.tokenizer, |
| 294 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 296 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 295 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 296 | batch_size=self.batch_size, generator=generator, | 298 | batch_size=self.batch_size, generator=generator, |
| @@ -299,24 +301,26 @@ class VlpnDataModule(): | |||
| 299 | ) | 301 | ) |
| 300 | 302 | ||
| 301 | self.train_dataloader = DataLoader( | 303 | self.train_dataloader = DataLoader( |
| 302 | train_dataset, | 304 | self.train_dataset, |
| 303 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 304 | ) | 306 | ) |
| 305 | 307 | ||
| 306 | if valid_set_size != 0: | 308 | if len(data_val) != 0: |
| 307 | self.data_val = self.pad_items(data_val) | 309 | data_val = self.pad_items(data_val) |
| 310 | |||
| 311 | if len(data_val) < self.valid_set_pad: | ||
| 312 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) | ||
| 308 | 313 | ||
| 309 | val_dataset = VlpnDataset( | 314 | self.val_dataset = VlpnDataset( |
| 310 | self.data_val, self.tokenizer, | 315 | data_val, self.tokenizer, |
| 311 | num_buckets=self.num_buckets, progressive_buckets=True, | 316 | num_buckets=self.num_buckets, progressive_buckets=True, |
| 312 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 317 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 313 | repeat=self.valid_set_repeat, | ||
| 314 | batch_size=self.batch_size, generator=generator, | 318 | batch_size=self.batch_size, generator=generator, |
| 315 | size=self.size, interpolation=self.interpolation, | 319 | size=self.size, interpolation=self.interpolation, |
| 316 | ) | 320 | ) |
| 317 | 321 | ||
| 318 | self.val_dataloader = DataLoader( | 322 | self.val_dataloader = DataLoader( |
| 319 | val_dataset, | 323 | self.val_dataset, |
| 320 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 324 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 321 | ) | 325 | ) |
| 322 | else: | 326 | else: |
| @@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset): | |||
| 332 | bucket_step_size: int = 64, | 336 | bucket_step_size: int = 64, |
| 333 | bucket_max_pixels: Optional[int] = None, | 337 | bucket_max_pixels: Optional[int] = None, |
| 334 | progressive_buckets: bool = False, | 338 | progressive_buckets: bool = False, |
| 335 | repeat: int = 1, | ||
| 336 | batch_size: int = 1, | 339 | batch_size: int = 1, |
| 337 | num_class_images: int = 0, | 340 | num_class_images: int = 0, |
| 338 | size: int = 768, | 341 | size: int = 768, |
| @@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset): | |||
| 341 | interpolation: str = "bicubic", | 344 | interpolation: str = "bicubic", |
| 342 | generator: Optional[torch.Generator] = None, | 345 | generator: Optional[torch.Generator] = None, |
| 343 | ): | 346 | ): |
| 344 | self.items = items * repeat | 347 | self.items = items |
| 345 | self.batch_size = batch_size | 348 | self.batch_size = batch_size |
| 346 | 349 | ||
| 347 | self.tokenizer = tokenizer | 350 | self.tokenizer = tokenizer |
