From b1aa5872460930840e59d328ab3cfacae09d9427 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 22:23:02 +0100 Subject: Fix --- data/csv.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index d1f3054..4986153 100644 --- a/data/csv.py +++ b/data/csv.py @@ -246,12 +246,14 @@ class CSVDataset(Dataset): example = {} - example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) + example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"]) example["cprompts"] = unprocessed_example["cprompts"] example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) + ) if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) -- cgit v1.2.3-54-g00ecf