diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 22 | ||||
-rw-r--r-- | data/keywords.py | 21 |
2 files changed, 22 insertions, 21 deletions
diff --git a/data/csv.py b/data/csv.py index a60733a..d1f3054 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,7 +1,6 @@ | |||
1 | import math | 1 | import math |
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | import numpy as np | ||
5 | from pathlib import Path | 4 | from pathlib import Path |
6 | from PIL import Image | 5 | from PIL import Image |
7 | from torch.utils.data import Dataset, DataLoader, random_split | 6 | from torch.utils.data import Dataset, DataLoader, random_split |
@@ -9,32 +8,13 @@ from torchvision import transforms | |||
9 | from typing import Dict, NamedTuple, List, Optional, Union, Callable | 8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
10 | 9 | ||
11 | from models.clip.prompt import PromptProcessor | 10 | from models.clip.prompt import PromptProcessor |
11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | ||
12 | 12 | ||
13 | 13 | ||
def prepare_prompt(prompt: Union[str, Dict[str, str]]):
    """Normalize a prompt into its dictionary form.

    A plain string is wrapped as ``{"content": prompt}``; any other value
    (expected to be a dict already) is returned unchanged, as the same object.
    """
    if isinstance(prompt, str):
        return {"content": prompt}
    return prompt
16 | 16 | ||
17 | 17 | ||
def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    """Join a list of keywords into a single ``", "``-separated prompt string.

    Args:
        prompt: Keywords to join. The list is never modified.
        dropout: When non-zero, each keyword is kept only if a fresh
            ``np.random.random()`` draw is greater than this value.
        shuffle: When true, the surviving keywords are put in random order.

    Returns:
        The remaining keywords joined with ``", "`` (empty string if none).
    """
    # Work on a private copy: np.random.shuffle reorders its argument in
    # place, which would otherwise leak into the caller's list whenever
    # dropout is 0 and shuffle is requested.
    keywords = list(prompt)
    if dropout != 0:
        keywords = [keyword for keyword in keywords if np.random.random() > dropout]
    if shuffle:
        np.random.shuffle(keywords)
    return ", ".join(keywords)
24 | |||
25 | |||
def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
    """Split a ``", "``-separated prompt into keywords, applying expansions.

    Every keyword that appears as a key in ``expansions`` is followed in the
    result by the ``", "``-separated pieces of its expansion string. Empty
    keywords (for example from an empty prompt) contribute nothing.
    """
    collected: list[str] = []
    for token in prompt.split(", "):
        # Build the expanded group first, then decide whether to keep it,
        # which keeps the same evaluation order as filtering after expansion.
        group = [token]
        if token in expansions:
            group = group + expansions[token].split(", ")
        if token != "":
            collected.extend(group)
    return collected
36 | |||
37 | |||
38 | class CSVDataItem(NamedTuple): | 18 | class CSVDataItem(NamedTuple): |
39 | instance_image_path: Path | 19 | instance_image_path: Path |
40 | class_image_path: Path | 20 | class_image_path: Path |
diff --git a/data/keywords.py b/data/keywords.py new file mode 100644 index 0000000..9e656f3 --- /dev/null +++ b/data/keywords.py | |||
@@ -0,0 +1,21 @@ | |||
1 | import numpy as np | ||
2 | |||
3 | |||
def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    """Render a keyword list as one ``", "``-separated prompt string.

    Args:
        prompt: Keywords to render. This list is left untouched.
        dropout: If non-zero, a keyword survives only when a new
            ``np.random.random()`` sample exceeds this value.
        shuffle: If true, the surviving keywords are randomly reordered.

    Returns:
        The surviving keywords joined by ``", "``; ``""`` when none remain.
    """
    # np.random.shuffle works in place, so operate on a copy. Without it the
    # caller's list was reordered whenever dropout == 0 and shuffle was set,
    # while the dropout path (which builds a new list) left it alone.
    selected = list(prompt)
    if dropout != 0:
        selected = [keyword for keyword in selected if np.random.random() > dropout]
    if shuffle:
        np.random.shuffle(selected)
    return ", ".join(selected)
10 | |||
11 | |||
def prompt_to_keywords(prompt: str, expansions: "dict[str, str] | None" = None) -> list[str]:
    """Split a ``", "``-separated prompt into keywords, applying expansions.

    Args:
        prompt: Prompt text whose keywords are separated by ``", "``.
        expansions: Optional mapping from a keyword to a ``", "``-separated
            string of extra keywords to insert right after it. ``None``
            (the default) means no expansions.

    Returns:
        The keywords in prompt order, each followed by its expansion
        keywords when it has an entry in ``expansions``. Empty keywords
        (for example from an empty prompt) are dropped.
    """
    # A None sentinel replaces the former ``{}`` default so the signature
    # does not carry a mutable object shared across calls. The annotation is
    # a string so it is not evaluated on interpreters lacking ``X | None``.
    if expansions is None:
        expansions = {}

    keywords: list[str] = []
    for keyword in prompt.split(", "):
        # Skip empty fragments before doing any expansion lookup.
        if keyword == "":
            continue
        keywords.append(keyword)
        if keyword in expansions:
            keywords.extend(expansions[keyword].split(", "))
    return keywords