summaryrefslogtreecommitdiffstats
path: root/data/textual_inversion/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/textual_inversion/csv.py')
-rw-r--r--data/textual_inversion/csv.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 852b1cb..4c5e27e 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule):
52 valid_set_size = int(len(self.data_full) * 0.2) 52 valid_set_size = int(len(self.data_full) * 0.2)
53 if self.valid_set_size: 53 if self.valid_set_size:
54 valid_set_size = min(valid_set_size, self.valid_set_size) 54 valid_set_size = min(valid_set_size, self.valid_set_size)
55 valid_set_size = max(valid_set_size, 1)
55 train_set_size = len(self.data_full) - valid_set_size 56 train_set_size = len(self.data_full) - valid_set_size
56 57
57 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) 58 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
58 59
59 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 60 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
60 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 61 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
61 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, 62 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
62 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 63 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
63 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) 64 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
64 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) 65 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)