From baba91864a45939cef4f77f6ca96ade7ae5ef274 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 24 Oct 2022 23:46:18 +0200
Subject: Advanced datasets

---
 data/csv.py | 64 +++++++++++++++++++++++++++++++++++++------------------------
 1 file changed, 39 insertions(+), 25 deletions(-)

(limited to 'data')

diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+            template = metadata["template"] if "template" in metadata else {}
+            items = metadata["items"] if "items" in metadata else []
+
+            items = [item for item in items if not "skip" in item or item["skip"] != True]
+            num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
@@ -85,10 +97,10 @@
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["nprompts"] = item.nprompt
 
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
-- 
cgit v1.2.3-70-g09d2
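
For context, prepare_data() now reads a plain JSON metadata file instead of a pandas data frame: an object with an optional "template" of format strings and an "items" list, where each item supplies an image path plus prompt/nprompt text (or dicts) and may be excluded with "skip". The sketch below writes one such file; the file name, image names, and prompt texts are hypothetical examples, only the keys ("template", "items", "image", "prompt", "nprompt", "skip", "content") mirror the code in this patch.

import json

# Hypothetical metadata file for the new loader; every value is an invented example.
metadata = {
    "template": {
        # prepare_subdata() formats this with each item's "image" value.
        "image": "images/{}",
        # "{content}" is filled from the item's prompt after prepare_prompt()
        # has wrapped a plain string into {"content": ...}.
        "prompt": "a photo of {content}",
        "nprompt": "low quality, {content}",
    },
    "items": [
        # A "{}" inside the prompt text survives templating and is later filled
        # with the instance/class identifier by CSVDataset.get_input_ids().
        {"image": "0001.jpg", "prompt": "a {} standing in snow", "nprompt": "blurry"},
        # A dict prompt is passed through prepare_prompt() unchanged, so its
        # keys must match the placeholders used in the template.
        {"image": "0002.jpg", "prompt": {"content": "a {} at night"}},
        # Items flagged with "skip": true are dropped before the train/val split.
        {"image": "0003.jpg", "prompt": "a {} underwater", "skip": True},
    ],
}

with open("metadata.json", "wt") as f:
    json.dump(metadata, f, indent=2)

Because prepare_subdata() falls back to "{}" for the image template and "{content}" for both prompt templates, a metadata file that contains only an "items" list still works; the "template" block is purely an optional shorthand for repeated prompt structure.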