summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/dreambooth/csv.py 9
-rw-r--r-- data/textual_inversion/csv.py 11
2 files changed, 15 insertions, 5 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 1676d35..4087226 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -21,6 +21,7 @@ class CSVDataModule(pl.LightningDataModule):
21 interpolation="bicubic", 21 interpolation="bicubic",
22 identifier="*", 22 identifier="*",
23 center_crop=False, 23 center_crop=False,
24 valid_set_size=None,
24 collate_fn=None): 25 collate_fn=None):
25 super().__init__() 26 super().__init__()
26 27
@@ -39,6 +40,7 @@ class CSVDataModule(pl.LightningDataModule):
39 self.identifier = identifier 40 self.identifier = identifier
40 self.center_crop = center_crop 41 self.center_crop = center_crop
41 self.interpolation = interpolation 42 self.interpolation = interpolation
43 self.valid_set_size = valid_set_size
42 self.collate_fn = collate_fn 44 self.collate_fn = collate_fn
43 self.batch_size = batch_size 45 self.batch_size = batch_size
44 46
@@ -50,8 +52,11 @@ class CSVDataModule(pl.LightningDataModule):
50 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] 52 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
51 53
52 def setup(self, stage=None): 54 def setup(self, stage=None):
53 train_set_size = int(len(self.data_full) * 0.8) 55 valid_set_size = int(len(self.data_full) * 0.2)
54 valid_set_size = len(self.data_full) - train_set_size 56 if self.valid_set_size:
57 valid_set_size = math.min(valid_set_size, self.valid_set_size)
58 train_set_size = len(self.data_full) - valid_set_size
59
55 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 60 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
56 61
57 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 62 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index f306c7a..e082511 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -18,7 +18,8 @@ class CSVDataModule(pl.LightningDataModule):
18 repeats=100, 18 repeats=100,
19 interpolation="bicubic", 19 interpolation="bicubic",
20 placeholder_token="*", 20 placeholder_token="*",
21 center_crop=False): 21 center_crop=False,
22 valid_set_size=None):
22 super().__init__() 23 super().__init__()
23 24
24 self.data_file = Path(data_file) 25 self.data_file = Path(data_file)
@@ -33,6 +34,7 @@ class CSVDataModule(pl.LightningDataModule):
33 self.placeholder_token = placeholder_token 34 self.placeholder_token = placeholder_token
34 self.center_crop = center_crop 35 self.center_crop = center_crop
35 self.interpolation = interpolation 36 self.interpolation = interpolation
37 self.valid_set_size = valid_set_size
36 38
37 self.batch_size = batch_size 39 self.batch_size = batch_size
38 40
@@ -44,8 +46,11 @@ class CSVDataModule(pl.LightningDataModule):
44 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] 46 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
45 47
46 def setup(self, stage=None): 48 def setup(self, stage=None):
47 train_set_size = int(len(self.data_full) * 0.8) 49 valid_set_size = int(len(self.data_full) * 0.2)
48 valid_set_size = len(self.data_full) - train_set_size 50 if self.valid_set_size:
51 valid_set_size = math.min(valid_set_size, self.valid_set_size)
52 train_set_size = len(self.data_full) - valid_set_size
53
49 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 54 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
50 55
51 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 56 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,