diff options
Diffstat (limited to 'data/dreambooth')
| -rw-r--r-- | data/dreambooth/csv.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4087226..08ed49c 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 22 | identifier="*", | 22 | identifier="*", |
| 23 | center_crop=False, | 23 | center_crop=False, |
| 24 | valid_set_size=None, | 24 | valid_set_size=None, |
| 25 | generator=None, | ||
| 25 | collate_fn=None): | 26 | collate_fn=None): |
| 26 | super().__init__() | 27 | super().__init__() |
| 27 | 28 | ||
| @@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 41 | self.center_crop = center_crop | 42 | self.center_crop = center_crop |
| 42 | self.interpolation = interpolation | 43 | self.interpolation = interpolation |
| 43 | self.valid_set_size = valid_set_size | 44 | self.valid_set_size = valid_set_size |
| 45 | self.generator = generator | ||
| 44 | self.collate_fn = collate_fn | 46 | self.collate_fn = collate_fn |
| 45 | self.batch_size = batch_size | 47 | self.batch_size = batch_size |
| 46 | 48 | ||
| @@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 54 | def setup(self, stage=None): | 56 | def setup(self, stage=None): |
| 55 | valid_set_size = int(len(self.data_full) * 0.2) | 57 | valid_set_size = int(len(self.data_full) * 0.2) |
| 56 | if self.valid_set_size: | 58 | if self.valid_set_size: |
| 57 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | 59 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 58 | train_set_size = len(self.data_full) - valid_set_size | 60 | train_set_size = len(self.data_full) - valid_set_size |
| 59 | 61 | ||
| 60 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 62 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
| 61 | 63 | ||
| 62 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 64 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
| 63 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 65 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
