From 1dab2629d49bb4b8aa35ac7a08b0cc1a440a96ca Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 15 Feb 2023 10:44:17 +0100 Subject: Better batch filling behavior --- data/csv.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/data/csv.py b/data/csv.py index 913268f..4ebdc1e 100644 --- a/data/csv.py +++ b/data/csv.py @@ -296,7 +296,7 @@ class VlpnDataModule(): data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - batch_size=self.batch_size, generator=generator, + batch_size=self.batch_size, fill_batch=True, generator=generator, size=self.size, interpolation=self.interpolation, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) @@ -338,6 +338,7 @@ class VlpnDataset(IterableDataset): bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, batch_size: int = 1, + fill_batch: bool = False, num_class_images: int = 0, size: int = 768, dropout: float = 0, @@ -347,6 +348,7 @@ class VlpnDataset(IterableDataset): ): self.items = items self.batch_size = batch_size + self.fill_batch = fill_batch self.tokenizer = tokenizer self.num_class_images = num_class_images @@ -408,7 +410,8 @@ class VlpnDataset(IterableDataset): if len(bucket_items) == 0: if len(batch) != 0: - batch = list(itertools.islice(itertools.cycle(batch), batch_size)) + if self.fill_batch: + batch = list(itertools.islice(itertools.cycle(batch), batch_size)) yield batch batch = [] @@ -448,5 +451,6 @@ class VlpnDataset(IterableDataset): batch.append(example) if len(batch) != 0: - batch = list(itertools.islice(itertools.cycle(batch), batch_size)) + if self.fill_batch: + batch = list(itertools.islice(itertools.cycle(batch), batch_size)) yield batch -- cgit v1.2.3-54-g00ecf