summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index df15c5a..5144c0a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule):
99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
100 instance_identifier=self.instance_identifier, 100 instance_identifier=self.instance_identifier,
101 size=self.size, interpolation=self.interpolation, 101 size=self.size, interpolation=self.interpolation,
102 center_crop=self.center_crop, repeats=self.repeats) 102 center_crop=self.center_crop)
103 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 103 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
104 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 104 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
105 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 105 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
@@ -157,6 +157,17 @@ class CSVDataset(Dataset):
157 def __len__(self): 157 def __len__(self):
158 return math.ceil(self._length / self.batch_size) * self.batch_size 158 return math.ceil(self._length / self.batch_size) * self.batch_size
159 159
160 def get_image(self, path):
161 if path in self.image_cache:
162 return self.image_cache[path]
163
164 image = Image.open(path)
165 if not image.mode == "RGB":
166 image = image.convert("RGB")
167 self.image_cache[path] = image
168
169 return image
170
160 def get_example(self, i): 171 def get_example(self, i):
161 item = self.data[i % self.num_instance_images] 172 item = self.data[i % self.num_instance_images]
162 cache_key = f"{item.instance_image_path}_{item.class_image_path}" 173 cache_key = f"{item.instance_image_path}_{item.class_image_path}"
@@ -169,30 +180,18 @@ class CSVDataset(Dataset):
169 example["prompts"] = item.prompt 180 example["prompts"] = item.prompt
170 example["nprompts"] = item.nprompt 181 example["nprompts"] = item.nprompt
171 182
172 if item.instance_image_path in self.image_cache: 183 example["instance_images"] = self.get_image(item.instance_image_path)
173 instance_image = self.image_cache[item.instance_image_path]
174 else:
175 instance_image = Image.open(item.instance_image_path)
176 if not instance_image.mode == "RGB":
177 instance_image = instance_image.convert("RGB")
178 self.image_cache[item.instance_image_path] = instance_image
179
180 example["instance_images"] = instance_image
181 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 184 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
182 item.prompt.format(self.instance_identifier) 185 item.prompt.format(self.instance_identifier)
183 ) 186 )
184 187
185 if self.num_class_images != 0: 188 if self.num_class_images != 0:
186 class_image = Image.open(item.class_image_path) 189 example["class_images"] = self.get_image(item.class_image_path)
187 if not class_image.mode == "RGB":
188 class_image = class_image.convert("RGB")
189
190 example["class_images"] = class_image
191 example["class_prompt_ids"] = self.prompt_processor.get_input_ids( 190 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
192 item.nprompt.format(self.class_identifier) 191 item.nprompt.format(self.class_identifier)
193 ) 192 )
194 193
195 self.cache[item.instance_image_path] = example 194 self.cache[cache_key] = example
196 return example 195 return example
197 196
198 def __getitem__(self, i): 197 def __getitem__(self, i):
@@ -204,7 +203,7 @@ class CSVDataset(Dataset):
204 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 203 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
205 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 204 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
206 205
207 if self.class_identifier is not None: 206 if self.num_class_images != 0:
208 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 207 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
209 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 208 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
210 209