summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth/csv.py')
-rw-r--r--data/dreambooth/csv.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
70 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 70 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
71 center_crop=self.center_crop, batch_size=self.batch_size) 71 center_crop=self.center_crop, batch_size=self.batch_size)
72 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 72 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
73 shuffle=True, collate_fn=self.collate_fn) 73 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
74 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) 74 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
75 pin_memory=True, collate_fn=self.collate_fn)
75 76
76 def train_dataloader(self): 77 def train_dataloader(self):
77 return self.train_dataloader_ 78 return self.train_dataloader_