diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py index 67ac43b..23b5299 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import torch | 2 | import torch |
| 3 | import json | 3 | import json |
| 4 | import numpy as np | ||
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| 5 | import pytorch_lightning as pl | 6 | import pytorch_lightning as pl |
| 6 | from PIL import Image | 7 | from PIL import Image |
| @@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
| 15 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
| 16 | 17 | ||
| 17 | 18 | ||
| 19 | def shuffle_prompt(prompt: str): | ||
| 20 | def handle_block(block: str): | ||
| 21 | words = block.split(", ") | ||
| 22 | np.random.shuffle(words) | ||
| 23 | return ", ".join(words) | ||
| 24 | |||
| 25 | prompt = prompt.split(". ") | ||
| 26 | prompt = [handle_block(b) for b in prompt] | ||
| 27 | np.random.shuffle(prompt) | ||
| 28 | prompt = ". ".join(prompt) | ||
| 29 | return prompt | ||
| 30 | |||
| 31 | |||
| 18 | class CSVDataItem(NamedTuple): | 32 | class CSVDataItem(NamedTuple): |
| 19 | instance_image_path: Path | 33 | instance_image_path: Path |
| 20 | class_image_path: Path | 34 | class_image_path: Path |
| @@ -190,30 +204,27 @@ class CSVDataset(Dataset): | |||
| 190 | item = self.data[i % self.num_instance_images] | 204 | item = self.data[i % self.num_instance_images] |
| 191 | 205 | ||
| 192 | example = {} | 206 | example = {} |
| 193 | |||
| 194 | example["prompts"] = item.prompt | 207 | example["prompts"] = item.prompt |
| 195 | example["nprompts"] = item.nprompt | 208 | example["nprompts"] = item.nprompt |
| 196 | |||
| 197 | example["instance_images"] = self.get_image(item.instance_image_path) | 209 | example["instance_images"] = self.get_image(item.instance_image_path) |
| 198 | example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier) | ||
| 199 | |||
| 200 | if self.num_class_images != 0: | 210 | if self.num_class_images != 0: |
| 201 | example["class_images"] = self.get_image(item.class_image_path) | 211 | example["class_images"] = self.get_image(item.class_image_path) |
| 202 | example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier) | ||
| 203 | 212 | ||
| 204 | return example | 213 | return example |
| 205 | 214 | ||
| 206 | def __getitem__(self, i): | 215 | def __getitem__(self, i): |
| 207 | example = {} | ||
| 208 | unprocessed_example = self.get_example(i) | 216 | unprocessed_example = self.get_example(i) |
| 209 | 217 | ||
| 210 | example["prompts"] = unprocessed_example["prompts"] | 218 | example = {} |
| 219 | |||
| 220 | example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) | ||
| 211 | example["nprompts"] = unprocessed_example["nprompts"] | 221 | example["nprompts"] = unprocessed_example["nprompts"] |
| 222 | |||
| 212 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 223 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 213 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 224 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) |
| 214 | 225 | ||
| 215 | if self.num_class_images != 0: | 226 | if self.num_class_images != 0: |
| 216 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 227 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 217 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 228 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) |
| 218 | 229 | ||
| 219 | return example | 230 | return example |
