From 0bc909409648a3cae0061c3de2b39e486473ae39 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Oct 2022 17:57:05 +0200 Subject: Added CLI arg to set dataloader worker num; improved text encoder handling with Dreambooth --- data/csv.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index f9b5e39..6bd7f9b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule): center_crop: bool = False, valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, - collate_fn=None + collate_fn=None, + num_workers: int = 0 ): super().__init__() @@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule): self.valid_set_size = valid_set_size self.generator = generator self.collate_fn = collate_fn + self.num_workers = num_workers self.batch_size = batch_size def prepare_subdata(self, template, data, num_class_images=1): @@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule): size=self.size, interpolation=self.interpolation, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn) + shuffle=True, pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn) + pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers) def train_dataloader(self): return self.train_dataloader_ -- cgit v1.2.3-54-g00ecf