diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/dreambooth/csv.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 85ed4a5..99bcf12 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -1,3 +1,4 @@ | |||
1 | import math | ||
1 | import os | 2 | import os |
2 | import pandas as pd | 3 | import pandas as pd |
3 | from pathlib import Path | 4 | from pathlib import Path |
@@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
57 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 58 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
58 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 59 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
59 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 60 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
60 | center_crop=self.center_crop, repeats=self.repeats) | 61 | center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) |
61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, | 62 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, |
62 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | ||
63 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 63 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
64 | center_crop=self.center_crop) | 64 | center_crop=self.center_crop, batch_size=self.batch_size) |
65 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 65 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
66 | shuffle=True, collate_fn=self.collate_fn) | 66 | shuffle=True, collate_fn=self.collate_fn) |
67 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) | 67 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) |
@@ -85,22 +85,24 @@ class CSVDataset(Dataset): | |||
85 | interpolation="bicubic", | 85 | interpolation="bicubic", |
86 | identifier="*", | 86 | identifier="*", |
87 | center_crop=False, | 87 | center_crop=False, |
88 | batch_size=1, | ||
88 | ): | 89 | ): |
89 | 90 | ||
90 | self.data = data | 91 | self.data = data |
91 | self.tokenizer = tokenizer | 92 | self.tokenizer = tokenizer |
92 | self.instance_prompt = instance_prompt | 93 | self.instance_prompt = instance_prompt |
94 | self.identifier = identifier | ||
95 | self.batch_size = batch_size | ||
96 | self.cache = {} | ||
93 | 97 | ||
94 | self.num_instance_images = len(self.data) | 98 | self.num_instance_images = len(self.data) |
95 | self._length = self.num_instance_images * repeats | 99 | self._length = self.num_instance_images * repeats |
96 | 100 | ||
97 | self.identifier = identifier | ||
98 | |||
99 | if class_data_root is not None: | 101 | if class_data_root is not None: |
100 | self.class_data_root = Path(class_data_root) | 102 | self.class_data_root = Path(class_data_root) |
101 | self.class_data_root.mkdir(parents=True, exist_ok=True) | 103 | self.class_data_root.mkdir(parents=True, exist_ok=True) |
102 | 104 | ||
103 | self.class_images = list(Path(class_data_root).iterdir()) | 105 | self.class_images = list(self.class_data_root.iterdir()) |
104 | self.num_class_images = len(self.class_images) | 106 | self.num_class_images = len(self.class_images) |
105 | self._length = max(self.num_class_images, self.num_instance_images) | 107 | self._length = max(self.num_class_images, self.num_instance_images) |
106 | 108 | ||
@@ -123,10 +125,8 @@ class CSVDataset(Dataset): | |||
123 | ] | 125 | ] |
124 | ) | 126 | ) |
125 | 127 | ||
126 | self.cache = {} | ||
127 | |||
128 | def __len__(self): | 128 | def __len__(self): |
129 | return self._length | 129 | return math.ceil(self._length / self.batch_size) * self.batch_size |
130 | 130 | ||
131 | def get_example(self, i): | 131 | def get_example(self, i): |
132 | image_path, text = self.data[i % self.num_instance_images] | 132 | image_path, text = self.data[i % self.num_instance_images] |