summaryrefslogtreecommitdiffstats
path: root/data/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'data/textual_inversion')
-rw-r--r--data/textual_inversion/csv.py8
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index e082511..3ac57df 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -19,7 +19,8 @@ class CSVDataModule(pl.LightningDataModule):
19 interpolation="bicubic", 19 interpolation="bicubic",
20 placeholder_token="*", 20 placeholder_token="*",
21 center_crop=False, 21 center_crop=False,
22 valid_set_size=None): 22 valid_set_size=None,
23 generator=None):
23 super().__init__() 24 super().__init__()
24 25
25 self.data_file = Path(data_file) 26 self.data_file = Path(data_file)
@@ -35,6 +36,7 @@ class CSVDataModule(pl.LightningDataModule):
35 self.center_crop = center_crop 36 self.center_crop = center_crop
36 self.interpolation = interpolation 37 self.interpolation = interpolation
37 self.valid_set_size = valid_set_size 38 self.valid_set_size = valid_set_size
39 self.generator = generator
38 40
39 self.batch_size = batch_size 41 self.batch_size = batch_size
40 42
@@ -48,10 +50,10 @@ class CSVDataModule(pl.LightningDataModule):
48 def setup(self, stage=None): 50 def setup(self, stage=None):
49 valid_set_size = int(len(self.data_full) * 0.2) 51 valid_set_size = int(len(self.data_full) * 0.2)
50 if self.valid_set_size: 52 if self.valid_set_size:
51 valid_set_size = math.min(valid_set_size, self.valid_set_size) 53 valid_set_size = min(valid_set_size, self.valid_set_size)
52 train_set_size = len(self.data_full) - valid_set_size 54 train_set_size = len(self.data_full) - valid_set_size
53 55
54 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 56 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
55 57
56 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 58 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
57 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 59 placeholder_token=self.placeholder_token, center_crop=self.center_crop)