summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
commitf23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
treed4c5666b291316ed95437cc1c917b03ef3b679da /data
parentAdded negative prompt support for training scripts (diff)
downloadtextual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.gz
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.bz2
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.zip
Small perf improvements
Diffstat (limited to 'data')
-rw-r--r--data/dreambooth/csv.py5
-rw-r--r--data/textual_inversion/csv.py4
2 files changed, 5 insertions, 4 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
70 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 70 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
71 center_crop=self.center_crop, batch_size=self.batch_size) 71 center_crop=self.center_crop, batch_size=self.batch_size)
72 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 72 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
73 shuffle=True, collate_fn=self.collate_fn) 73 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
74 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) 74 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
75 pin_memory=True, collate_fn=self.collate_fn)
75 76
76 def train_dataloader(self): 77 def train_dataloader(self):
77 return self.train_dataloader_ 78 return self.train_dataloader_
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
60 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 60 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
61 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, 61 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
62 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 62 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
63 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 63 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
64 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) 64 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
65 65
66 def train_dataloader(self): 66 def train_dataloader(self):
67 return self.train_dataloader_ 67 return self.train_dataloader_