summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 14
1 files changed, 10 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py
index aad970c..316c099 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule):
72 ] 72 ]
73 73
74 def prepare_data(self): 74 def prepare_data(self):
75 metadata = pd.read_csv(self.data_file) 75 metadata = pd.read_json(self.data_file)
76 metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"] 76 metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
77 num_images = len(metadata) 77 num_images = len(metadata)
78 78
79 valid_set_size = int(num_images * 0.2) 79 valid_set_size = int(num_images * 0.2)
@@ -163,6 +163,12 @@ class CSVDataset(Dataset):
163 163
164 example = {} 164 example = {}
165 165
166 if isinstance(item.prompt, str):
167 item.prompt = [item.prompt]
168
169 if isinstance(item.nprompt, str):
170 item.nprompt = [item.nprompt]
171
166 example["prompts"] = item.prompt 172 example["prompts"] = item.prompt
167 example["nprompts"] = item.nprompt 173 example["nprompts"] = item.nprompt
168 174
@@ -177,7 +183,7 @@ class CSVDataset(Dataset):
177 example["instance_images"] = instance_image 183 example["instance_images"] = instance_image
178 example["instance_prompt_ids"] = self.tokenizer( 184 example["instance_prompt_ids"] = self.tokenizer(
179 item.prompt.format(self.instance_identifier), 185 item.prompt.format(self.instance_identifier),
180 padding="do_not_pad", 186 padding="max_length",
181 truncation=True, 187 truncation=True,
182 max_length=self.tokenizer.model_max_length, 188 max_length=self.tokenizer.model_max_length,
183 ).input_ids 189 ).input_ids
@@ -190,7 +196,7 @@ class CSVDataset(Dataset):
190 example["class_images"] = class_image 196 example["class_images"] = class_image
191 example["class_prompt_ids"] = self.tokenizer( 197 example["class_prompt_ids"] = self.tokenizer(
192 item.prompt.format(self.class_identifier), 198 item.prompt.format(self.class_identifier),
193 padding="do_not_pad", 199 padding="max_length",
194 truncation=True, 200 truncation=True,
195 max_length=self.tokenizer.model_max_length, 201 max_length=self.tokenizer.model_max_length,
196 ).input_ids 202 ).input_ids