diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 29 |
1 file changed, 20 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py index 67ac43b..23b5299 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import math | 1 | import math |
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | import numpy as np | ||
4 | from pathlib import Path | 5 | from pathlib import Path |
5 | import pytorch_lightning as pl | 6 | import pytorch_lightning as pl |
6 | from PIL import Image | 7 | from PIL import Image |
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
15 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
16 | 17 | ||
17 | 18 | ||
19 | def shuffle_prompt(prompt: str): | ||
20 | def handle_block(block: str): | ||
21 | words = block.split(", ") | ||
22 | np.random.shuffle(words) | ||
23 | return ", ".join(words) | ||
24 | |||
25 | prompt = prompt.split(". ") | ||
26 | prompt = [handle_block(b) for b in prompt] | ||
27 | np.random.shuffle(prompt) | ||
28 | prompt = ". ".join(prompt) | ||
29 | return prompt | ||
30 | |||
31 | |||
18 | class CSVDataItem(NamedTuple): | 32 | class CSVDataItem(NamedTuple): |
19 | instance_image_path: Path | 33 | instance_image_path: Path |
20 | class_image_path: Path | 34 | class_image_path: Path |
@@ -190,30 +204,27 @@ class CSVDataset(Dataset): | |||
190 | item = self.data[i % self.num_instance_images] | 204 | item = self.data[i % self.num_instance_images] |
191 | 205 | ||
192 | example = {} | 206 | example = {} |
193 | |||
194 | example["prompts"] = item.prompt | 207 | example["prompts"] = item.prompt |
195 | example["nprompts"] = item.nprompt | 208 | example["nprompts"] = item.nprompt |
196 | |||
197 | example["instance_images"] = self.get_image(item.instance_image_path) | 209 | example["instance_images"] = self.get_image(item.instance_image_path) |
198 | example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier) | ||
199 | |||
200 | if self.num_class_images != 0: | 210 | if self.num_class_images != 0: |
201 | example["class_images"] = self.get_image(item.class_image_path) | 211 | example["class_images"] = self.get_image(item.class_image_path) |
202 | example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier) | ||
203 | 212 | ||
204 | return example | 213 | return example |
205 | 214 | ||
206 | def __getitem__(self, i): | 215 | def __getitem__(self, i): |
207 | example = {} | ||
208 | unprocessed_example = self.get_example(i) | 216 | unprocessed_example = self.get_example(i) |
209 | 217 | ||
210 | example["prompts"] = unprocessed_example["prompts"] | 218 | example = {} |
219 | |||
220 | example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) | ||
211 | example["nprompts"] = unprocessed_example["nprompts"] | 221 | example["nprompts"] = unprocessed_example["nprompts"] |
222 | |||
212 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 223 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
213 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 224 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) |
214 | 225 | ||
215 | if self.num_class_images != 0: | 226 | if self.num_class_images != 0: |
216 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 227 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
217 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 228 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) |
218 | 229 | ||
219 | return example | 230 | return example |