From b57ca669a150d9313447612fb8c37668f4f2a80d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 9 Jan 2023 10:19:37 +0100 Subject: Add --valid_set_repeat --- data/csv.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 2f0a392..584a40c 100644 --- a/data/csv.py +++ b/data/csv.py @@ -125,6 +125,7 @@ class VlpnDataModule(): interpolation: str = "bicubic", template_key: str = "template", valid_set_size: Optional[int] = None, + valid_set_repeat: int = 1, seed: Optional[int] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, collate_fn=None, @@ -152,6 +153,7 @@ class VlpnDataModule(): self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size + self.valid_set_repeat = valid_set_repeat self.seed = seed self.filter = filter self.collate_fn = collate_fn @@ -243,6 +245,7 @@ class VlpnDataModule(): val_dataset = VlpnDataset( self.data_val, self.prompt_processor, + repeat=self.valid_set_repeat, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, ) @@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset): bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, + repeat: int = 1, batch_size: int = 1, num_class_images: int = 0, size: int = 768, @@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset): interpolation: str = "bicubic", generator: Optional[torch.Generator] = None, ): - self.items = items + self.items = items * repeat self.batch_size = batch_size self.prompt_processor = prompt_processor -- cgit v1.2.3-54-g00ecf