| author | Volpeon <git@volpeon.ink> | 2022-12-23 23:02:01 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-23 23:02:01 +0100 |
| commit | 3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de | |
| tree | 7b12a26c195e7298bb6cbc993ad0dd0f322fede4 | |
| parent | num_class_images is now class images per train image | |
Better dataset prompt handling
| -rw-r--r-- | data/csv.py | 25 |
|---|---|---|
| -rw-r--r-- | train_dreambooth.py | 2 |
| -rw-r--r-- | train_ti.py | 2 |

3 files changed, 18 insertions, 11 deletions
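The core of the change is in `keywords_to_prompt` (see the data/csv.py diff below): shuffling the keyword list is now opt-in via a `shuffle` flag, so class and negative prompts can be joined deterministically while instance prompts keep dropout and shuffling as augmentation. A minimal standalone sketch of the updated helper, copied from the diff and exercised with made-up keywords:

```python
import numpy as np

# Standalone copy of the updated helper, for illustration only.
def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    if dropout != 0:
        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
    if shuffle:
        np.random.shuffle(prompt)
    return ", ".join(prompt)

# Made-up example keywords.
keywords = ["portrait", "oil painting", "dramatic lighting"]

# Instance prompts: dropout and shuffling both apply (augmentation).
print(keywords_to_prompt(keywords, dropout=0.1, shuffle=True))

# Class/negative prompts: deterministic join, no dropout, no shuffle.
print(keywords_to_prompt(keywords))
```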
```diff
diff --git a/data/csv.py b/data/csv.py
index edce2b1..265293b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,10 +15,11 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt


-def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
     if dropout != 0:
         prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    np.random.shuffle(prompt)
+    if shuffle:
+        np.random.shuffle(prompt)
     return ", ".join(prompt)


@@ -38,8 +39,8 @@ class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
-    cprompt: str
-    nprompt: str
+    cprompt: list[str]
+    nprompt: list[str]


 class CSVDataModule():
@@ -104,8 +105,14 @@ class CSVDataModule():
                     prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                     expansions
                 ),
-                cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
-                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                prompt_to_keywords(
+                    cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
+                    expansions
+                ),
+                prompt_to_keywords(
+                    prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                    expansions
+                ),
             )
             for item in data
         ]
@@ -253,9 +260,9 @@ class CSVDataset(Dataset):

         example = {}

-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
+        example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"])
+        example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"])

         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
```
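With this change, `CSVDataItem` carries all three prompt fields as keyword lists; they are only joined into strings when `CSVDataset.__getitem__` builds an example via `keywords_to_prompt`. A small self-contained sketch of the updated tuple, using hypothetical file paths and keywords:

```python
from pathlib import Path
from typing import NamedTuple

# Sketch of the updated CSVDataItem: cprompt and nprompt are now keyword
# lists, like prompt, instead of pre-formatted strings.
class CSVDataItem(NamedTuple):
    instance_image_path: Path
    class_image_path: Path
    prompt: list[str]
    cprompt: list[str]
    nprompt: list[str]

# Hypothetical item for illustration.
item = CSVDataItem(
    instance_image_path=Path("instance/cat-01.jpg"),
    class_image_path=Path("class/cat-01.jpg"),
    prompt=["photo of a cat", "sitting on a chair"],
    cprompt=["photo of a cat"],
    nprompt=["blurry", "low quality"],
)

# The keyword lists are joined only when an example is built,
# e.g. via keywords_to_prompt; a plain join shows the idea:
print(", ".join(item.nprompt))  # "blurry, low quality"
```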
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2f913e7..1a79b2b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -687,7 +687,7 @@ def main():
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)

-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
                 prompt = [item.cprompt for item in batch]
```
```diff
diff --git a/train_ti.py b/train_ti.py
index e272b5d..cc208f0 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -621,7 +621,7 @@ def main():
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)

-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
                 prompt = [item.cprompt for item in batch]
```
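Both training scripts make the same change: the `torch.autocast("cuda")` wrapper around class-image generation is dropped and only `torch.inference_mode()` remains, so the pipeline runs in whatever dtype it was loaded with instead of being autocast on the fly, while autograd tracking stays disabled. A minimal, generic sketch of what the remaining context manager provides (the small linear model is a hypothetical stand-in for the diffusion pipeline):

```python
import torch

# Hypothetical stand-in model; the real loop wraps a diffusion pipeline.
model = torch.nn.Linear(8, 8)
x = torch.randn(1, 8)

# torch.inference_mode() alone: no autograd tracking or version-counter
# bookkeeping, but dtypes are left untouched. The removed torch.autocast("cuda")
# context would additionally have run eligible ops in reduced precision on GPU.
with torch.inference_mode():
    y = model(x)

print(y.requires_grad)        # False
print(torch.is_inference(y))  # True: tensor was created under inference mode
```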
