summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
committerVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
commit8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree152c99815bbd8b2659d0dabe63c98f63151c97c2 /data/keywords.py
parentFix LoRA training with DAdan (diff)
downloadtextual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip
Update
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/data/keywords.py b/data/keywords.py
index 8632d67..83fe9ff 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -8,7 +8,7 @@ def keywords_to_str(
8 undroppable_keywords: list[str] = [], 8 undroppable_keywords: list[str] = [],
9 dropout: float = 0, 9 dropout: float = 0,
10 shuffle: bool = False, 10 shuffle: bool = False,
11 npgenerator: Optional[np.random.Generator] = None 11 npgenerator: Optional[np.random.Generator] = None,
12) -> str: 12) -> str:
13 if dropout != 0: 13 if dropout != 0:
14 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 14 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
@@ -23,7 +23,11 @@ def keywords_to_str(
23 23
24def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: 24def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
25 def expand_keyword(keyword: str) -> list[str]: 25 def expand_keyword(keyword: str) -> list[str]:
26 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] 26 return (
27 [keyword] + expansions[keyword].split(", ")
28 if keyword in expansions
29 else [keyword]
30 )
27 31
28 return [ 32 return [
29 kw 33 kw