From 6c072fe50b3bfc561f22e5d591212d30de3c2dd2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 12:08:16 +0200 Subject: Fixed euler_a generator argument --- data/dreambooth/csv.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'data/dreambooth') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4087226..08ed49c 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule): identifier="*", center_crop=False, valid_set_size=None, + generator=None, collate_fn=None): super().__init__() @@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule): self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size + self.generator = generator self.collate_fn = collate_fn self.batch_size = batch_size @@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule): def setup(self, stage=None): valid_set_size = int(len(self.data_full) * 0.2) if self.valid_set_size: - valid_set_size = math.min(valid_set_size, self.valid_set_size) + valid_set_size = min(valid_set_size, self.valid_set_size) train_set_size = len(self.data_full) - valid_set_size - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) + self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, class_data_root=self.class_data_root, class_prompt=self.class_prompt, -- cgit v1.2.3-70-g09d2