From 728dfcf57c30f40236b3a00d7380c4e0057cacb3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Oct 2022 22:08:58 +0200 Subject: Implemented extended prompt limit --- data/csv.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index aad970c..316c099 100644 --- a/data/csv.py +++ b/data/csv.py @@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule): ] def prepare_data(self): - metadata = pd.read_csv(self.data_file) - metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"] + metadata = pd.read_json(self.data_file) + metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True] num_images = len(metadata) valid_set_size = int(num_images * 0.2) @@ -163,6 +163,12 @@ class CSVDataset(Dataset): example = {} + if isinstance(item.prompt, str): + item.prompt = [item.prompt] + + if isinstance(item.nprompt, str): + item.nprompt = [item.nprompt] + example["prompts"] = item.prompt example["nprompts"] = item.nprompt @@ -177,7 +183,7 @@ class CSVDataset(Dataset): example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( item.prompt.format(self.instance_identifier), - padding="do_not_pad", + padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids @@ -190,7 +196,7 @@ class CSVDataset(Dataset): example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( item.prompt.format(self.class_identifier), - padding="do_not_pad", + padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids -- cgit v1.2.3-70-g09d2