diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py index df15c5a..5144c0a 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 100 | instance_identifier=self.instance_identifier, | 100 | instance_identifier=self.instance_identifier, |
| 101 | size=self.size, interpolation=self.interpolation, | 101 | size=self.size, interpolation=self.interpolation, |
| 102 | center_crop=self.center_crop, repeats=self.repeats) | 102 | center_crop=self.center_crop) |
| 103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| @@ -157,6 +157,17 @@ class CSVDataset(Dataset): | |||
| 157 | def __len__(self): | 157 | def __len__(self): |
| 158 | return math.ceil(self._length / self.batch_size) * self.batch_size | 158 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 159 | 159 | ||
| 160 | def get_image(self, path): | ||
| 161 | if path in self.image_cache: | ||
| 162 | return self.image_cache[path] | ||
| 163 | |||
| 164 | image = Image.open(path) | ||
| 165 | if not image.mode == "RGB": | ||
| 166 | image = image.convert("RGB") | ||
| 167 | self.image_cache[path] = image | ||
| 168 | |||
| 169 | return image | ||
| 170 | |||
| 160 | def get_example(self, i): | 171 | def get_example(self, i): |
| 161 | item = self.data[i % self.num_instance_images] | 172 | item = self.data[i % self.num_instance_images] |
| 162 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" | 173 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" |
| @@ -169,30 +180,18 @@ class CSVDataset(Dataset): | |||
| 169 | example["prompts"] = item.prompt | 180 | example["prompts"] = item.prompt |
| 170 | example["nprompts"] = item.nprompt | 181 | example["nprompts"] = item.nprompt |
| 171 | 182 | ||
| 172 | if item.instance_image_path in self.image_cache: | 183 | example["instance_images"] = self.get_image(item.instance_image_path) |
| 173 | instance_image = self.image_cache[item.instance_image_path] | ||
| 174 | else: | ||
| 175 | instance_image = Image.open(item.instance_image_path) | ||
| 176 | if not instance_image.mode == "RGB": | ||
| 177 | instance_image = instance_image.convert("RGB") | ||
| 178 | self.image_cache[item.instance_image_path] = instance_image | ||
| 179 | |||
| 180 | example["instance_images"] = instance_image | ||
| 181 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 184 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 182 | item.prompt.format(self.instance_identifier) | 185 | item.prompt.format(self.instance_identifier) |
| 183 | ) | 186 | ) |
| 184 | 187 | ||
| 185 | if self.num_class_images != 0: | 188 | if self.num_class_images != 0: |
| 186 | class_image = Image.open(item.class_image_path) | 189 | example["class_images"] = self.get_image(item.class_image_path) |
| 187 | if not class_image.mode == "RGB": | ||
| 188 | class_image = class_image.convert("RGB") | ||
| 189 | |||
| 190 | example["class_images"] = class_image | ||
| 191 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( | 190 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 192 | item.nprompt.format(self.class_identifier) | 191 | item.nprompt.format(self.class_identifier) |
| 193 | ) | 192 | ) |
| 194 | 193 | ||
| 195 | self.cache[item.instance_image_path] = example | 194 | self.cache[cache_key] = example |
| 196 | return example | 195 | return example |
| 197 | 196 | ||
| 198 | def __getitem__(self, i): | 197 | def __getitem__(self, i): |
| @@ -204,7 +203,7 @@ class CSVDataset(Dataset): | |||
| 204 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 203 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 205 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 204 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
| 206 | 205 | ||
| 207 | if self.class_identifier is not None: | 206 | if self.num_class_images != 0: |
| 208 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 207 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 209 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 208 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
| 210 | 209 | ||
