From c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 10 Apr 2023 14:48:25 +0200 Subject: Randomize dataset across cycles --- data/csv.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index 818fcd9..3af9925 100644 --- a/data/csv.py +++ b/data/csv.py @@ -192,7 +192,7 @@ class VlpnDataModule(): valid_set_size: Optional[int] = None, train_set_pad: Optional[int] = None, valid_set_pad: Optional[int] = None, - seed: Optional[int] = None, + generator: Optional[torch.Generator] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, ): @@ -224,10 +224,10 @@ class VlpnDataModule(): self.valid_set_size = valid_set_size self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size - self.seed = seed self.filter = filter self.batch_size = batch_size self.dtype = dtype + self.generator = generator def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: tpl_image = template["image"] if "image" in template else "{}" @@ -304,16 +304,12 @@ class VlpnDataModule(): train_set_size = max(num_images - valid_set_size, 1) valid_set_size = num_images - train_set_size - generator = torch.Generator(device="cpu") - if self.seed is not None: - generator = generator.manual_seed(self.seed) - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) if valid_set_size == 0: data_train, data_val = items, items else: - data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) + data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) data_train = self.pad_items(data_train, self.num_class_images) @@ -324,7 +320,7 @@ class VlpnDataModule(): data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - batch_size=self.batch_size, fill_batch=True, generator=generator, + batch_size=self.batch_size, fill_batch=True, generator=self.generator, size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) @@ -344,7 +340,7 @@ class VlpnDataModule(): data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - batch_size=self.batch_size, generator=generator, + batch_size=self.batch_size, generator=self.generator, size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, ) -- cgit v1.2.3-54-g00ecf