summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 913268f..4ebdc1e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -296,7 +296,7 @@ class VlpnDataModule():
296 data_train, self.tokenizer, 296 data_train, self.tokenizer,
297 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 297 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
298 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 298 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
299 batch_size=self.batch_size, generator=generator, 299 batch_size=self.batch_size, fill_batch=True, generator=generator,
300 size=self.size, interpolation=self.interpolation, 300 size=self.size, interpolation=self.interpolation,
301 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, 301 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
302 ) 302 )
@@ -338,6 +338,7 @@ class VlpnDataset(IterableDataset):
338 bucket_max_pixels: Optional[int] = None, 338 bucket_max_pixels: Optional[int] = None,
339 progressive_buckets: bool = False, 339 progressive_buckets: bool = False,
340 batch_size: int = 1, 340 batch_size: int = 1,
341 fill_batch: bool = False,
341 num_class_images: int = 0, 342 num_class_images: int = 0,
342 size: int = 768, 343 size: int = 768,
343 dropout: float = 0, 344 dropout: float = 0,
@@ -347,6 +348,7 @@ class VlpnDataset(IterableDataset):
347 ): 348 ):
348 self.items = items 349 self.items = items
349 self.batch_size = batch_size 350 self.batch_size = batch_size
351 self.fill_batch = fill_batch
350 352
351 self.tokenizer = tokenizer 353 self.tokenizer = tokenizer
352 self.num_class_images = num_class_images 354 self.num_class_images = num_class_images
@@ -408,7 +410,8 @@ class VlpnDataset(IterableDataset):
408 410
409 if len(bucket_items) == 0: 411 if len(bucket_items) == 0:
410 if len(batch) != 0: 412 if len(batch) != 0:
411 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 413 if self.fill_batch:
414 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
412 yield batch 415 yield batch
413 batch = [] 416 batch = []
414 417
@@ -448,5 +451,6 @@ class VlpnDataset(IterableDataset):
448 batch.append(example) 451 batch.append(example)
449 452
450 if len(batch) != 0: 453 if len(batch) != 0:
451 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 454 if self.fill_batch:
455 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
452 yield batch 456 yield batch