summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 09:35:57 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 09:35:57 +0100
commitdfc51d6d74410acefab86d2938a2b864be603668 (patch)
treebed57fcd481bc243324950f93e890294703533f6 /data
parentFix (diff)
downloadtextual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.gz
textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.bz2
textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.zip
Update
Diffstat (limited to 'data')
-rw-r--r--data/csv.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py
index b45ac77..0810c2c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple):
39 instance_image_path: Path 39 instance_image_path: Path
40 class_image_path: Path 40 class_image_path: Path
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: list[str] 42 cprompt: str
43 nprompt: list[str] 43 nprompt: str
44 44
45 45
46class CSVDataModule(): 46class CSVDataModule():
@@ -105,14 +105,14 @@ class CSVDataModule():
105 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 105 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
106 expansions 106 expansions
107 ), 107 ),
108 prompt_to_keywords( 108 keywords_to_prompt(prompt_to_keywords(
109 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 109 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
110 expansions 110 expansions
111 ), 111 )),
112 prompt_to_keywords( 112 keywords_to_prompt(prompt_to_keywords(
113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions 114 expansions
115 ), 115 )),
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -261,8 +261,8 @@ class CSVDataset(Dataset):
261 example = {} 261 example = {}
262 262
263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) 263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
264 example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) 264 example["cprompts"] = unprocessed_example["cprompts"]
265 example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) 265 example["nprompts"] = unprocessed_example["nprompts"]
266 266
267 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 267 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
268 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) 268 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])