From dfc51d6d74410acefab86d2938a2b864be603668 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 09:35:57 +0100 Subject: Update --- data/csv.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index b45ac77..0810c2c 100644 --- a/data/csv.py +++ b/data/csv.py @@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] - cprompt: list[str] - nprompt: list[str] + cprompt: str + nprompt: str class CSVDataModule(): @@ -105,14 +105,14 @@ class CSVDataModule(): prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions ), - prompt_to_keywords( + keywords_to_prompt(prompt_to_keywords( cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions - ), - prompt_to_keywords( + )), + keywords_to_prompt(prompt_to_keywords( nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions - ), + )), ) for item in data ] @@ -261,8 +261,8 @@ class CSVDataset(Dataset): example = {} example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) - example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) - example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) + example["cprompts"] = unprocessed_example["cprompts"] + example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) -- cgit v1.2.3-54-g00ecf