From 306f2bfb620e6882737658bd3694c79365d75e4b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 18 Oct 2022 15:23:40 +0200 Subject: Improved prompt handling --- data/csv.py | 83 +++++++++++++++++++++++++++---------------------------------- 1 file changed, 37 insertions(+), 46 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 316c099..4c91ded 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,11 +1,14 @@ import math import pandas as pd +import torch from pathlib import Path import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms -from typing import NamedTuple, List +from typing import NamedTuple, List, Optional + +from models.clip.prompt import PromptProcessor class CSVDataItem(NamedTuple): @@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple): class CSVDataModule(pl.LightningDataModule): def __init__( self, - batch_size, - data_file, - tokenizer, - instance_identifier, - class_identifier=None, - class_subdir="cls", - num_class_images=100, - size=512, - repeats=100, - interpolation="bicubic", - center_crop=False, - valid_set_size=None, - generator=None, + batch_size: int, + data_file: str, + prompt_processor: PromptProcessor, + instance_identifier: str, + class_identifier: Optional[str] = None, + class_subdir: str = "cls", + num_class_images: int = 100, + size: int = 512, + repeats: int = 1, + interpolation: str = "bicubic", + center_crop: bool = False, + valid_set_size: Optional[int] = None, + generator: Optional[torch.Generator] = None, collate_fn=None ): super().__init__() @@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule): self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images - self.tokenizer = tokenizer + self.prompt_processor = prompt_processor self.instance_identifier = instance_identifier self.class_identifier = class_identifier self.size = size @@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule): self.data_root.joinpath(item.image), self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), item.prompt, - item.nprompt if "nprompt" in item else "" + item.nprompt ) for item in data for i in range(image_multiplier) @@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule): self.data_val = self.prepare_subdata(data_val) def setup(self, stage=None): - train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, + train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, + val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) @@ -113,19 +116,19 @@ class CSVDataset(Dataset): def __init__( self, data: List[CSVDataItem], - tokenizer, - instance_identifier, - batch_size=1, - class_identifier=None, - num_class_images=0, - size=512, - repeats=1, - interpolation="bicubic", - center_crop=False, + prompt_processor: PromptProcessor, + instance_identifier: str, + batch_size: int = 1, + class_identifier: Optional[str] = None, + num_class_images: int = 0, + size: int = 512, + repeats: int = 1, + interpolation: str = "bicubic", + center_crop: bool = False, ): self.data = data - self.tokenizer = tokenizer + self.prompt_processor = prompt_processor self.batch_size = batch_size self.instance_identifier = instance_identifier self.class_identifier = class_identifier @@ -163,12 +166,6 @@ class CSVDataset(Dataset): example = {} - if isinstance(item.prompt, str): - item.prompt = [item.prompt] - - if isinstance(item.nprompt, str): - item.nprompt = [item.nprompt] - example["prompts"] = item.prompt example["nprompts"] = item.nprompt @@ -181,12 +178,9 @@ class CSVDataset(Dataset): self.image_cache[item.instance_image_path] = instance_image example["instance_images"] = instance_image - example["instance_prompt_ids"] = self.tokenizer( - item.prompt.format(self.instance_identifier), - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + item.prompt.format(self.instance_identifier) + ) if self.num_class_images != 0: class_image = Image.open(item.class_image_path) @@ -194,12 +188,9 @@ class CSVDataset(Dataset): class_image = class_image.convert("RGB") example["class_images"] = class_image - example["class_prompt_ids"] = self.tokenizer( - item.prompt.format(self.class_identifier), - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids + example["class_prompt_ids"] = self.prompt_processor.get_input_ids( + item.nprompt.format(self.class_identifier) + ) self.cache[item.instance_image_path] = example return example -- cgit v1.2.3-70-g09d2