Diffstat (limited to 'data/textual_inversion')
-rw-r--r--    data/textual_inversion/csv.py    11
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index f306c7a..e082511 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -18,7 +18,8 @@ class CSVDataModule(pl.LightningDataModule):
                  repeats=100,
                  interpolation="bicubic",
                  placeholder_token="*",
-                 center_crop=False):
+                 center_crop=False,
+                 valid_set_size=None):
         super().__init__()
 
         self.data_file = Path(data_file)
@@ -33,6 +34,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.placeholder_token = placeholder_token
         self.center_crop = center_crop
         self.interpolation = interpolation
+        self.valid_set_size = valid_set_size
 
         self.batch_size = batch_size
 
@@ -44,8 +46,11 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
 
     def setup(self, stage=None):
-        train_set_size = int(len(self.data_full) * 0.8)
-        valid_set_size = len(self.data_full) - train_set_size
+        valid_set_size = int(len(self.data_full) * 0.2)
+        if self.valid_set_size:
+            valid_set_size = min(valid_set_size, self.valid_set_size)
+        train_set_size = len(self.data_full) - valid_set_size
+
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
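For context, a minimal sketch of what the new split logic does, pulled out of setup() as a standalone helper. The function name split_sizes and the example numbers are illustrative, not part of the commit; only the arithmetic mirrors the change (default 20% validation share, optionally capped by valid_set_size):

# Sketch of the validation-split arithmetic introduced by this change (assumed behavior).
def split_sizes(total_items, valid_set_size=None):
    # Default: reserve 20% of the dataset for validation.
    valid = int(total_items * 0.2)
    # If a cap was passed to CSVDataModule, never exceed it.
    if valid_set_size:
        valid = min(valid, valid_set_size)
    train = total_items - valid
    return train, valid

# Example: 100 rows with validation capped at 10 items vs. the uncapped default.
print(split_sizes(100, valid_set_size=10))  # (90, 10)
print(split_sizes(100))                     # (80, 20)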
