diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/dreambooth/csv.py | 5 | ||||
| -rw-r--r-- | data/textual_inversion/csv.py | 4 |
2 files changed, 5 insertions, 4 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 71aa1eb..c0b0067 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 70 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 70 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
| 71 | center_crop=self.center_crop, batch_size=self.batch_size) | 71 | center_crop=self.center_crop, batch_size=self.batch_size) |
| 72 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 72 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 73 | shuffle=True, collate_fn=self.collate_fn) | 73 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 74 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) | 74 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| 75 | pin_memory=True, collate_fn=self.collate_fn) | ||
| 75 | 76 | ||
| 76 | def train_dataloader(self): | 77 | def train_dataloader(self): |
| 77 | return self.train_dataloader_ | 78 | return self.train_dataloader_ |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 64f0c28..852b1cb 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 60 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 60 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, | 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, |
| 62 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 62 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 63 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) | 63 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) |
| 64 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) | 64 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) |
| 65 | 65 | ||
| 66 | def train_dataloader(self): | 66 | def train_dataloader(self): |
| 67 | return self.train_dataloader_ | 67 | return self.train_dataloader_ |
