summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 22
-rw-r--r-- data/keywords.py 21
2 files changed, 22 insertions, 21 deletions
diff --git a/data/csv.py b/data/csv.py
index a60733a..d1f3054 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,6 @@
1import math 1import math
2import torch 2import torch
3import json 3import json
4import numpy as np
5from pathlib import Path 4from pathlib import Path
6from PIL import Image 5from PIL import Image
7from torch.utils.data import Dataset, DataLoader, random_split 6from torch.utils.data import Dataset, DataLoader, random_split
@@ -9,32 +8,13 @@ from torchvision import transforms
9from typing import Dict, NamedTuple, List, Optional, Union, Callable 8from typing import Dict, NamedTuple, List, Optional, Union, Callable
10 9
11from models.clip.prompt import PromptProcessor 10from models.clip.prompt import PromptProcessor
11from data.keywords import prompt_to_keywords, keywords_to_prompt
12 12
13 13
14def prepare_prompt(prompt: Union[str, Dict[str, str]]): 14def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt 15 return {"content": prompt} if isinstance(prompt, str) else prompt
16 16
17 17
18def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
19 if dropout != 0:
20 prompt = [keyword for keyword in prompt if np.random.random() > dropout]
21 if shuffle:
22 np.random.shuffle(prompt)
23 return ", ".join(prompt)
24
25
26def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
27 def expand_keyword(keyword: str) -> list[str]:
28 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
29
30 return [
31 kw
32 for keyword in prompt.split(", ")
33 for kw in expand_keyword(keyword)
34 if keyword != ""
35 ]
36
37
38class CSVDataItem(NamedTuple): 18class CSVDataItem(NamedTuple):
39 instance_image_path: Path 19 instance_image_path: Path
40 class_image_path: Path 20 class_image_path: Path
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
1import numpy as np
2
3
def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    """Join a list of keywords into a single ``", "``-separated prompt string.

    Args:
        prompt: Keywords to join. The caller's list is never modified.
        dropout: When non-zero, each keyword is kept only if a fresh draw from
            ``np.random.random()`` is strictly greater than this value, so
            ``1.0`` drops every keyword.
        shuffle: When true, the surviving keywords are put in random order
            (using NumPy's global random state) before joining.

    Returns:
        The remaining keywords joined with ``", "``; an empty string when no
        keyword is left.
    """
    if dropout != 0:
        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
    if shuffle:
        # np.random.shuffle works in place; shuffle a private copy so that the
        # caller's list keeps its order when no dropout copy was made above.
        prompt = list(prompt)
        np.random.shuffle(prompt)
    return ", ".join(prompt)
10
11
def prompt_to_keywords(prompt: str, expansions: "dict[str, str] | None" = None) -> list[str]:
    """Split a ``", "``-separated prompt into keywords, applying expansions.

    Args:
        prompt: Prompt text whose keywords are separated by ``", "``.
        expansions: Optional mapping from a keyword to a ``", "``-separated
            string of extra keywords. When a keyword from the prompt is a key
            of this mapping, the extra keywords are inserted right after it.
            ``None`` (the default) means no expansions.

    Returns:
        The keywords in prompt order, each followed by its expansion (if any).
        Empty keywords in the prompt are skipped, so an empty prompt yields an
        empty list.
    """
    # A None sentinel avoids sharing one mutable default dict across calls.
    if expansions is None:
        expansions = {}

    keywords = []
    for keyword in prompt.split(", "):
        # Skip empty entries before doing any expansion work for them.
        if keyword == "":
            continue
        keywords.append(keyword)
        if keyword in expansions:
            keywords.extend(expansions[keyword].split(", "))
    return keywords