summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth/csv.py')
-rw-r--r--data/dreambooth/csv.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 4087226..08ed49c 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule):
22 identifier="*", 22 identifier="*",
23 center_crop=False, 23 center_crop=False,
24 valid_set_size=None, 24 valid_set_size=None,
25 generator=None,
25 collate_fn=None): 26 collate_fn=None):
26 super().__init__() 27 super().__init__()
27 28
@@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule):
41 self.center_crop = center_crop 42 self.center_crop = center_crop
42 self.interpolation = interpolation 43 self.interpolation = interpolation
43 self.valid_set_size = valid_set_size 44 self.valid_set_size = valid_set_size
45 self.generator = generator
44 self.collate_fn = collate_fn 46 self.collate_fn = collate_fn
45 self.batch_size = batch_size 47 self.batch_size = batch_size
46 48
@@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule):
54 def setup(self, stage=None): 56 def setup(self, stage=None):
55 valid_set_size = int(len(self.data_full) * 0.2) 57 valid_set_size = int(len(self.data_full) * 0.2)
56 if self.valid_set_size: 58 if self.valid_set_size:
57 valid_set_size = math.min(valid_set_size, self.valid_set_size) 59 valid_set_size = min(valid_set_size, self.valid_set_size)
58 train_set_size = len(self.data_full) - valid_set_size 60 train_set_size = len(self.data_full) - valid_set_size
59 61
60 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 62 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
61 63
62 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 64 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
63 class_data_root=self.class_data_root, class_prompt=self.class_prompt, 65 class_data_root=self.class_data_root, class_prompt=self.class_prompt,