diff options
author | Volpeon <git@volpeon.ink> | 2023-02-15 10:44:17 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-15 10:44:17 +0100 |
commit | 1dab2629d49bb4b8aa35ac7a08b0cc1a440a96ca (patch) | |
tree | 024eba03176352feab6744a7e3c4cee8e84385da /data | |
parent | Dataset: Repeat data to fill batch to batch_size (diff) | |
download | textual-inversion-diff-1dab2629d49bb4b8aa35ac7a08b0cc1a440a96ca.tar.gz textual-inversion-diff-1dab2629d49bb4b8aa35ac7a08b0cc1a440a96ca.tar.bz2 textual-inversion-diff-1dab2629d49bb4b8aa35ac7a08b0cc1a440a96ca.zip |
Better batch filling behavior
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 913268f..4ebdc1e 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -296,7 +296,7 @@ class VlpnDataModule(): | |||
296 | data_train, self.tokenizer, | 296 | data_train, self.tokenizer, |
297 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 297 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
298 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 298 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
299 | batch_size=self.batch_size, generator=generator, | 299 | batch_size=self.batch_size, fill_batch=True, generator=generator, |
300 | size=self.size, interpolation=self.interpolation, | 300 | size=self.size, interpolation=self.interpolation, |
301 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 301 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
302 | ) | 302 | ) |
@@ -338,6 +338,7 @@ class VlpnDataset(IterableDataset): | |||
338 | bucket_max_pixels: Optional[int] = None, | 338 | bucket_max_pixels: Optional[int] = None, |
339 | progressive_buckets: bool = False, | 339 | progressive_buckets: bool = False, |
340 | batch_size: int = 1, | 340 | batch_size: int = 1, |
341 | fill_batch: bool = False, | ||
341 | num_class_images: int = 0, | 342 | num_class_images: int = 0, |
342 | size: int = 768, | 343 | size: int = 768, |
343 | dropout: float = 0, | 344 | dropout: float = 0, |
@@ -347,6 +348,7 @@ class VlpnDataset(IterableDataset): | |||
347 | ): | 348 | ): |
348 | self.items = items | 349 | self.items = items |
349 | self.batch_size = batch_size | 350 | self.batch_size = batch_size |
351 | self.fill_batch = fill_batch | ||
350 | 352 | ||
351 | self.tokenizer = tokenizer | 353 | self.tokenizer = tokenizer |
352 | self.num_class_images = num_class_images | 354 | self.num_class_images = num_class_images |
@@ -408,7 +410,8 @@ class VlpnDataset(IterableDataset): | |||
408 | 410 | ||
409 | if len(bucket_items) == 0: | 411 | if len(bucket_items) == 0: |
410 | if len(batch) != 0: | 412 | if len(batch) != 0: |
411 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 413 | if self.fill_batch: |
414 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
412 | yield batch | 415 | yield batch |
413 | batch = [] | 416 | batch = [] |
414 | 417 | ||
@@ -448,5 +451,6 @@ class VlpnDataset(IterableDataset): | |||
448 | batch.append(example) | 451 | batch.append(example) |
449 | 452 | ||
450 | if len(batch) != 0: | 453 | if len(batch) != 0: |
451 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 454 | if self.fill_batch: |
455 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
452 | yield batch | 456 | yield batch |