diff options
Diffstat (limited to 'data/textual_inversion')
| -rw-r--r-- | data/textual_inversion/csv.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 852b1cb..4c5e27e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 52 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
| 53 | if self.valid_set_size: | 53 | if self.valid_set_size: |
| 54 | valid_set_size = min(valid_set_size, self.valid_set_size) | 54 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 55 | valid_set_size = max(valid_set_size, 1) | ||
| 55 | train_set_size = len(self.data_full) - valid_set_size | 56 | train_set_size = len(self.data_full) - valid_set_size |
| 56 | 57 | ||
| 57 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) | 58 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
| 58 | 59 | ||
| 59 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 60 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 60 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 61 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, | 62 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 62 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 63 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 63 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) | 64 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) |
| 64 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) | 65 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) |
