From 847ec3b6c43c89ef3649715f86ecfed370b6e442 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 24 Oct 2022 07:34:30 +0200 Subject: Update --- data/csv.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index df15c5a..5144c0a 100644 --- a/data/csv.py +++ b/data/csv.py @@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule): val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats) + center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, @@ -157,6 +157,17 @@ class CSVDataset(Dataset): def __len__(self): return math.ceil(self._length / self.batch_size) * self.batch_size + def get_image(self, path): + if path in self.image_cache: + return self.image_cache[path] + + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + self.image_cache[path] = image + + return image + def get_example(self, i): item = self.data[i % self.num_instance_images] cache_key = f"{item.instance_image_path}_{item.class_image_path}" @@ -169,30 +180,18 @@ class CSVDataset(Dataset): example["prompts"] = item.prompt example["nprompts"] = item.nprompt - if item.instance_image_path in self.image_cache: - instance_image = self.image_cache[item.instance_image_path] - else: - instance_image = Image.open(item.instance_image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - self.image_cache[item.instance_image_path] = instance_image - - example["instance_images"] = instance_image + example["instance_images"] = self.get_image(item.instance_image_path) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( item.prompt.format(self.instance_identifier) ) if self.num_class_images != 0: - class_image = Image.open(item.class_image_path) - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - - example["class_images"] = class_image + example["class_images"] = self.get_image(item.class_image_path) example["class_prompt_ids"] = self.prompt_processor.get_input_ids( item.nprompt.format(self.class_identifier) ) - self.cache[item.instance_image_path] = example + self.cache[cache_key] = example return example def __getitem__(self, i): @@ -204,7 +203,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_identifier is not None: + if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] -- cgit v1.2.3-70-g09d2