diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py index df15c5a..5144c0a 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
100 | instance_identifier=self.instance_identifier, | 100 | instance_identifier=self.instance_identifier, |
101 | size=self.size, interpolation=self.interpolation, | 101 | size=self.size, interpolation=self.interpolation, |
102 | center_crop=self.center_crop, repeats=self.repeats) | 102 | center_crop=self.center_crop) |
103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
@@ -157,6 +157,17 @@ class CSVDataset(Dataset): | |||
157 | def __len__(self): | 157 | def __len__(self): |
158 | return math.ceil(self._length / self.batch_size) * self.batch_size | 158 | return math.ceil(self._length / self.batch_size) * self.batch_size |
159 | 159 | ||
160 | def get_image(self, path): | ||
161 | if path in self.image_cache: | ||
162 | return self.image_cache[path] | ||
163 | |||
164 | image = Image.open(path) | ||
165 | if not image.mode == "RGB": | ||
166 | image = image.convert("RGB") | ||
167 | self.image_cache[path] = image | ||
168 | |||
169 | return image | ||
170 | |||
160 | def get_example(self, i): | 171 | def get_example(self, i): |
161 | item = self.data[i % self.num_instance_images] | 172 | item = self.data[i % self.num_instance_images] |
162 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" | 173 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" |
@@ -169,30 +180,18 @@ class CSVDataset(Dataset): | |||
169 | example["prompts"] = item.prompt | 180 | example["prompts"] = item.prompt |
170 | example["nprompts"] = item.nprompt | 181 | example["nprompts"] = item.nprompt |
171 | 182 | ||
172 | if item.instance_image_path in self.image_cache: | 183 | example["instance_images"] = self.get_image(item.instance_image_path) |
173 | instance_image = self.image_cache[item.instance_image_path] | ||
174 | else: | ||
175 | instance_image = Image.open(item.instance_image_path) | ||
176 | if not instance_image.mode == "RGB": | ||
177 | instance_image = instance_image.convert("RGB") | ||
178 | self.image_cache[item.instance_image_path] = instance_image | ||
179 | |||
180 | example["instance_images"] = instance_image | ||
181 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 184 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
182 | item.prompt.format(self.instance_identifier) | 185 | item.prompt.format(self.instance_identifier) |
183 | ) | 186 | ) |
184 | 187 | ||
185 | if self.num_class_images != 0: | 188 | if self.num_class_images != 0: |
186 | class_image = Image.open(item.class_image_path) | 189 | example["class_images"] = self.get_image(item.class_image_path) |
187 | if not class_image.mode == "RGB": | ||
188 | class_image = class_image.convert("RGB") | ||
189 | |||
190 | example["class_images"] = class_image | ||
191 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( | 190 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( |
192 | item.nprompt.format(self.class_identifier) | 191 | item.nprompt.format(self.class_identifier) |
193 | ) | 192 | ) |
194 | 193 | ||
195 | self.cache[item.instance_image_path] = example | 194 | self.cache[cache_key] = example |
196 | return example | 195 | return example |
197 | 196 | ||
198 | def __getitem__(self, i): | 197 | def __getitem__(self, i): |
@@ -204,7 +203,7 @@ class CSVDataset(Dataset): | |||
204 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 203 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
205 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 204 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
206 | 205 | ||
207 | if self.class_identifier is not None: | 206 | if self.num_class_images != 0: |
208 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 207 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
209 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 208 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
210 | 209 | ||