From f23fd5184b8ba4ec04506495f4a61726e50756f7 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 3 Oct 2022 17:38:44 +0200
Subject: Small perf improvements

---
 data/dreambooth/csv.py        | 5 +++--
 data/textual_inversion/csv.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

(limited to 'data')

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
                                  size=self.size, interpolation=self.interpolation, identifier=self.identifier,
                                  center_crop=self.center_crop, batch_size=self.batch_size)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
+                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
+                                          pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
         return self.train_dataloader_
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
                                    placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
 
     def train_dataloader(self):
         return self.train_dataloader_
-- 
cgit v1.2.3-70-g09d2