From 1bd386f98bb076fe62696808e02a9bd9b9b64b42 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:47:12 +0100 Subject: Improved class prompt handling --- data/csv.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index d400757..e25dd3f 100644 --- a/data/csv.py +++ b/data/csv.py @@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] + cprompt: str nprompt: str @@ -47,8 +48,6 @@ class CSVDataModule(): batch_size: int, data_file: str, prompt_processor: PromptProcessor, - instance_identifier: str, - class_identifier: Optional[str] = None, class_subdir: str = "cls", num_class_images: int = 100, size: int = 512, @@ -77,8 +76,6 @@ class CSVDataModule(): self.num_class_images = num_class_images self.prompt_processor = prompt_processor - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier self.size = size self.repeats = repeats self.dropout = dropout @@ -96,14 +93,18 @@ class CSVDataModule(): def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" + cprompt = template["cprompt"] if "cprompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" return [ CSVDataItem( self.data_root.joinpath(image.format(item["image"])), None, - prompt_to_keywords(prompt.format( - **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), + prompt_to_keywords( + prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), + expansions + ), + cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), ) for item in data @@ -123,6 +124,7 @@ class CSVDataModule(): item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, + item.cprompt, item.nprompt, ) for item in items @@ -160,12 +162,10 @@ class CSVDataModule(): def setup(self, stage=None): train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, - instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, - instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, @@ -187,9 +187,7 @@ class CSVDataset(Dataset): self, data: List[CSVDataItem], prompt_processor: PromptProcessor, - instance_identifier: str, batch_size: int = 1, - class_identifier: Optional[str] = None, num_class_images: int = 0, size: int = 512, repeats: int = 1, @@ -201,8 +199,6 @@ class CSVDataset(Dataset): self.data = data self.prompt_processor = prompt_processor self.batch_size = batch_size - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier self.num_class_images = num_class_images self.dropout = dropout self.image_cache = {} @@ -239,14 +235,12 @@ class CSVDataset(Dataset): return image - def get_input_ids(self, prompt, identifier): - return self.prompt_processor.get_input_ids(prompt.format(identifier)) - def get_example(self, i): item = self.data[i % self.num_instance_images] example = {} example["prompts"] = item.prompt + example["cprompts"] = item.cprompt example["nprompts"] = item.nprompt example["instance_images"] = self.get_image(item.instance_image_path) if self.num_class_images != 0: @@ -260,13 +254,14 @@ class CSVDataset(Dataset): example = {} example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) + example["cprompts"] = unprocessed_example["cprompts"] example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) - example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) + example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) return example -- cgit v1.2.3-70-g09d2