From 14beba63391e1ddc9a145bb638d9306086ad1a5c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Oct 2022 08:34:07 +0200 Subject: Training: Create multiple class images per training image --- data/csv.py | 54 ++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 14 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index abd329d..dcaf7d3 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,5 +1,3 @@ -import math -import os import pandas as pd from pathlib import Path import pytorch_lightning as pl @@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule): instance_identifier, class_identifier=None, class_subdir="db_cls", + num_class_images=2, size=512, repeats=100, interpolation="bicubic", @@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule): self.data_root = self.data_file.parent self.class_root = self.data_root.joinpath(class_subdir) self.class_root.mkdir(parents=True, exist_ok=True) + self.num_class_images = num_class_images self.tokenizer = tokenizer self.instance_identifier = instance_identifier @@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule): def prepare_data(self): metadata = pd.read_csv(self.data_file) - instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] - class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] - prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) - self.data = [(i, c, p, n) - for i, c, p, n, s - in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) - if s != "x"] + instance_image_paths = [ + self.data_root.joinpath(f) + for f in metadata['image'].values + for i in range(self.num_class_images) + ] + class_image_paths = [ + self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") + for f in metadata['image'].values + for i in range(self.num_class_images) + ] + prompts = [ + prompt + for prompt in metadata['prompt'].values + for i in range(self.num_class_images) + ] + nprompts = [ + nprompt + for nprompt in metadata['nprompt'].values + for i in range(self.num_class_images) + ] if 'nprompt' in metadata else [""] * len(instance_image_paths) + skips = [ + skip + for skip in metadata['skip'].values + for i in range(self.num_class_images) + ] if 'skip' in metadata else [""] * len(instance_image_paths) + self.data = [ + (i, c, p, n) + for i, c, p, n, s + in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) + if s != "x" + ] def setup(self, stage=None): valid_set_size = int(len(self.data) * 0.2) @@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule): train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, + num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) val_dataset = CSVDataset(self.data_val, self.tokenizer, @@ -93,6 +116,7 @@ class CSVDataset(Dataset): tokenizer, instance_identifier, class_identifier=None, + num_class_images=2, size=512, repeats=1, interpolation="bicubic", @@ -103,6 +127,7 @@ class CSVDataset(Dataset): self.tokenizer = tokenizer self.instance_identifier = instance_identifier self.class_identifier = class_identifier + self.num_class_images = num_class_images self.cache = {} self.num_instance_images = len(self.data) @@ -128,9 +153,10 @@ class CSVDataset(Dataset): def get_example(self, i): instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] + cache_key = f"{instance_image_path}_{class_image_path}" - if instance_image_path in self.cache: - return self.cache[instance_image_path] + if cache_key in self.cache: + return self.cache[cache_key] example = {} @@ -149,7 +175,7 @@ class CSVDataset(Dataset): max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_identifier is not None: + if self.num_class_images != 0: class_image = Image.open(class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") -- cgit v1.2.3-70-g09d2