summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py29
1 files changed, 20 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
1import math 1import math
2import torch 2import torch
3import json 3import json
4import numpy as np
4from pathlib import Path 5from pathlib import Path
5import pytorch_lightning as pl 6import pytorch_lightning as pl
6from PIL import Image 7from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt 16 return {"content": prompt} if isinstance(prompt, str) else prompt
16 17
17 18
19def shuffle_prompt(prompt: str):
20 def handle_block(block: str):
21 words = block.split(", ")
22 np.random.shuffle(words)
23 return ", ".join(words)
24
25 prompt = prompt.split(". ")
26 prompt = [handle_block(b) for b in prompt]
27 np.random.shuffle(prompt)
28 prompt = ". ".join(prompt)
29 return prompt
30
31
18class CSVDataItem(NamedTuple): 32class CSVDataItem(NamedTuple):
19 instance_image_path: Path 33 instance_image_path: Path
20 class_image_path: Path 34 class_image_path: Path
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
190 item = self.data[i % self.num_instance_images] 204 item = self.data[i % self.num_instance_images]
191 205
192 example = {} 206 example = {}
193
194 example["prompts"] = item.prompt 207 example["prompts"] = item.prompt
195 example["nprompts"] = item.nprompt 208 example["nprompts"] = item.nprompt
196
197 example["instance_images"] = self.get_image(item.instance_image_path) 209 example["instance_images"] = self.get_image(item.instance_image_path)
198 example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
199
200 if self.num_class_images != 0: 210 if self.num_class_images != 0:
201 example["class_images"] = self.get_image(item.class_image_path) 211 example["class_images"] = self.get_image(item.class_image_path)
202 example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
203 212
204 return example 213 return example
205 214
206 def __getitem__(self, i): 215 def __getitem__(self, i):
207 example = {}
208 unprocessed_example = self.get_example(i) 216 unprocessed_example = self.get_example(i)
209 217
210 example["prompts"] = unprocessed_example["prompts"] 218 example = {}
219
220 example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
211 example["nprompts"] = unprocessed_example["nprompts"] 221 example["nprompts"] = unprocessed_example["nprompts"]
222
212 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 223 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
213 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 224 example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
214 225
215 if self.num_class_images != 0: 226 if self.num_class_images != 0:
216 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 227 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
217 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 228 example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
218 229
219 return example 230 return example