From 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Oct 2022 21:56:54 +0200 Subject: Update --- data/csv.py | 162 +++++++++++++++++++++++++++++++----------------------------- 1 file changed, 85 insertions(+), 77 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index dcaf7d3..8637ac1 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,27 +1,38 @@ +import math import pandas as pd from pathlib import Path import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms +from typing import NamedTuple, List + + +class CSVDataItem(NamedTuple): + instance_image_path: Path + class_image_path: Path + prompt: str + nprompt: str class CSVDataModule(pl.LightningDataModule): - def __init__(self, - batch_size, - data_file, - tokenizer, - instance_identifier, - class_identifier=None, - class_subdir="db_cls", - num_class_images=2, - size=512, - repeats=100, - interpolation="bicubic", - center_crop=False, - valid_set_size=None, - generator=None, - collate_fn=None): + def __init__( + self, + batch_size, + data_file, + tokenizer, + instance_identifier, + class_identifier=None, + class_subdir="db_cls", + num_class_images=100, + size=512, + repeats=100, + interpolation="bicubic", + center_crop=False, + valid_set_size=None, + generator=None, + collate_fn=None + ): super().__init__() self.data_file = Path(data_file) @@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule): self.collate_fn = collate_fn self.batch_size = batch_size + def prepare_subdata(self, data, num_class_images=1): + image_multiplier = max(math.ceil(num_class_images / len(data)), 1) + + return [ + CSVDataItem( + self.data_root.joinpath(item.image), + self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), + item.prompt, + item.nprompt if "nprompt" in item else "" + ) + for item in data + if "skip" not in item or item.skip != "x" + for i in range(image_multiplier) + ] + def prepare_data(self): metadata = pd.read_csv(self.data_file) - instance_image_paths = [ - self.data_root.joinpath(f) - for f in metadata['image'].values - for i in range(self.num_class_images) - ] - class_image_paths = [ - self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") - for f in metadata['image'].values - for i in range(self.num_class_images) - ] - prompts = [ - prompt - for prompt in metadata['prompt'].values - for i in range(self.num_class_images) - ] - nprompts = [ - nprompt - for nprompt in metadata['nprompt'].values - for i in range(self.num_class_images) - ] if 'nprompt' in metadata else [""] * len(instance_image_paths) - skips = [ - skip - for skip in metadata['skip'].values - for i in range(self.num_class_images) - ] if 'skip' in metadata else [""] * len(instance_image_paths) - self.data = [ - (i, c, p, n) - for i, c, p, n, s - in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) - if s != "x" - ] + metadata = list(metadata.itertuples()) + num_images = len(metadata) - def setup(self, stage=None): - valid_set_size = int(len(self.data) * 0.2) + valid_set_size = int(num_images * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) valid_set_size = max(valid_set_size, 1) - train_set_size = len(self.data) - valid_set_size + train_set_size = num_images - valid_set_size - self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) + data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator) - train_dataset = CSVDataset(self.data_train, self.tokenizer, + self.data_train = self.prepare_subdata(data_train, self.num_class_images) + self.data_val = self.prepare_subdata(data_val) + + def setup(self, stage=None): + train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - val_dataset = CSVDataset(self.data_val, self.tokenizer, + val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True, collate_fn=self.collate_fn) def train_dataloader(self): @@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule): class CSVDataset(Dataset): - def __init__(self, - data, - tokenizer, - instance_identifier, - class_identifier=None, - num_class_images=2, - size=512, - repeats=1, - interpolation="bicubic", - center_crop=False, - ): + def __init__( + self, + data: List[CSVDataItem], + tokenizer, + instance_identifier, + batch_size=1, + class_identifier=None, + num_class_images=0, + size=512, + repeats=1, + interpolation="bicubic", + center_crop=False, + ): self.data = data self.tokenizer = tokenizer + self.batch_size = batch_size self.instance_identifier = instance_identifier self.class_identifier = class_identifier self.num_class_images = num_class_images self.cache = {} + self.image_cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats @@ -149,46 +153,50 @@ class CSVDataset(Dataset): ) def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size def get_example(self, i): - instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] - cache_key = f"{instance_image_path}_{class_image_path}" + item = self.data[i % self.num_instance_images] + cache_key = f"{item.instance_image_path}_{item.class_image_path}" if cache_key in self.cache: return self.cache[cache_key] example = {} - example["prompts"] = prompt - example["nprompts"] = nprompt + example["prompts"] = item.prompt + example["nprompts"] = item.nprompt - instance_image = Image.open(instance_image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") + if item.instance_image_path in self.image_cache: + instance_image = self.image_cache[item.instance_image_path] + else: + instance_image = Image.open(item.instance_image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + self.image_cache[item.instance_image_path] = instance_image example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - prompt.format(self.instance_identifier), + item.prompt.format(self.instance_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids if self.num_class_images != 0: - class_image = Image.open(class_image_path) + class_image = Image.open(item.class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - prompt.format(self.class_identifier), + item.prompt.format(self.class_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - self.cache[instance_image_path] = example + self.cache[item.instance_image_path] = example return example def __getitem__(self, i): -- cgit v1.2.3-54-g00ecf