summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 13:26:32 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 13:26:32 +0100
commit3396ca881ed3f3521617cd9024eea56975191d32 (patch)
tree3189c3bbe77b211152d11b524d0fe3a7016441ee /data/keywords.py
parentFix (diff)
downloadtextual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.gz
textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.bz2
textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.zip
Update
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
1import numpy as np
2
3
4def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0:
6 prompt = [keyword for keyword in prompt if np.random.random() > dropout]
7 if shuffle:
8 np.random.shuffle(prompt)
9 return ", ".join(prompt)
10
11
12def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]:
13 def expand_keyword(keyword: str) -> list[str]:
14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
15
16 return [
17 kw
18 for keyword in prompt.split(", ")
19 for kw in expand_keyword(keyword)
20 if keyword != ""
21 ]