summaryrefslogtreecommitdiffstats
path: root/data/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'data/textual_inversion')
-rw-r--r--data/textual_inversion/csv.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
60 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 60 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
61 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, 61 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
62 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 62 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
63 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 63 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
64 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) 64 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
65 65
66 def train_dataloader(self): 66 def train_dataloader(self):
67 return self.train_dataloader_ 67 return self.train_dataloader_