summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-02 11:14:03 +0200
committerVolpeon <git@volpeon.ink>2023-04-02 11:14:03 +0200
commite3669927b47b5367a3348d30c4b318da84af661d (patch)
treea9740db8cea9149eaab99f08f7cb8778f8e643b7 /data/keywords.py
parentUpdate (diff)
downloadtextual-inversion-diff-e3669927b47b5367a3348d30c4b318da84af661d.tar.gz
textual-inversion-diff-e3669927b47b5367a3348d30c4b318da84af661d.tar.bz2
textual-inversion-diff-e3669927b47b5367a3348d30c4b318da84af661d.zip
Update dataset format: Separate prompt and keywords
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/data/keywords.py b/data/keywords.py
index 9e656f3..7385809 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,21 +1,21 @@
1import numpy as np 1import numpy as np
2 2
3 3
4def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: 4def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0: 5 if dropout != 0:
6 prompt = [keyword for keyword in prompt if np.random.random() > dropout] 6 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 if shuffle: 7 if shuffle:
8 np.random.shuffle(prompt) 8 np.random.shuffle(keywords)
9 return ", ".join(prompt) 9 return ", ".join(keywords)
10 10
11 11
12def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: 12def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
13 def expand_keyword(keyword: str) -> list[str]: 13 def expand_keyword(keyword: str) -> list[str]:
14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] 14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
15 15
16 return [ 16 return [
17 kw 17 kw
18 for keyword in prompt.split(", ") 18 for keyword in s.split(", ")
19 for kw in expand_keyword(keyword) 19 for kw in expand_keyword(keyword)
20 if keyword != "" 20 if keyword != ""
21 ] 21 ]