diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-24 09:35:57 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-24 09:35:57 +0100 |
| commit | dfc51d6d74410acefab86d2938a2b864be603668 (patch) | |
| tree | bed57fcd481bc243324950f93e890294703533f6 /data | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.gz textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.bz2 textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.zip | |
Update
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py index b45ac77..0810c2c 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple): | |||
| 39 | instance_image_path: Path | 39 | instance_image_path: Path |
| 40 | class_image_path: Path | 40 | class_image_path: Path |
| 41 | prompt: list[str] | 41 | prompt: list[str] |
| 42 | cprompt: list[str] | 42 | cprompt: str |
| 43 | nprompt: list[str] | 43 | nprompt: str |
| 44 | 44 | ||
| 45 | 45 | ||
| 46 | class CSVDataModule(): | 46 | class CSVDataModule(): |
| @@ -105,14 +105,14 @@ class CSVDataModule(): | |||
| 105 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 105 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 106 | expansions | 106 | expansions |
| 107 | ), | 107 | ), |
| 108 | prompt_to_keywords( | 108 | keywords_to_prompt(prompt_to_keywords( |
| 109 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 109 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 110 | expansions | 110 | expansions |
| 111 | ), | 111 | )), |
| 112 | prompt_to_keywords( | 112 | keywords_to_prompt(prompt_to_keywords( |
| 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 114 | expansions | 114 | expansions |
| 115 | ), | 115 | )), |
| 116 | ) | 116 | ) |
| 117 | for item in data | 117 | for item in data |
| 118 | ] | 118 | ] |
| @@ -261,8 +261,8 @@ class CSVDataset(Dataset): | |||
| 261 | example = {} | 261 | example = {} |
| 262 | 262 | ||
| 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) | 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) |
| 264 | example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) | 264 | example["cprompts"] = unprocessed_example["cprompts"] |
| 265 | example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) | 265 | example["nprompts"] = unprocessed_example["nprompts"] |
| 266 | 266 | ||
| 267 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 267 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 268 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) | 268 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
