author    Volpeon <git@volpeon.ink>  2022-12-23 23:02:01 +0100
committer Volpeon <git@volpeon.ink>  2022-12-23 23:02:01 +0100
commit    3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de (patch)
tree      7b12a26c195e7298bb6cbc993ad0dd0f322fede4
parent    num_class_images is now class images per train image (diff)
Better dataset prompt handling
 data/csv.py         | 25 ++++++++++++++++---------
 train_dreambooth.py |  2 +-
 train_ti.py         |  2 +-
 3 files changed, 18 insertions(+), 11 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index edce2b1..265293b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,10 +15,11 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
     if dropout != 0:
         prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    np.random.shuffle(prompt)
+    if shuffle:
+        np.random.shuffle(prompt)
     return ", ".join(prompt)
 
 
@@ -38,8 +39,8 @@ class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
-    cprompt: str
-    nprompt: str
+    cprompt: list[str]
+    nprompt: list[str]
 
 
 class CSVDataModule():
@@ -104,8 +105,14 @@ class CSVDataModule():
                     prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                     expansions
                 ),
-                cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
-                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                prompt_to_keywords(
+                    cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
+                    expansions
+                ),
+                prompt_to_keywords(
+                    prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                    expansions
+                ),
             )
             for item in data
         ]
@@ -253,9 +260,9 @@ class CSVDataset(Dataset):
 
         example = {}
 
-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
+        example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"])
+        example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"])
 
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
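
For reference, a minimal self-contained sketch of how the revised keywords_to_prompt helper behaves after this commit; the keyword lists in the usage lines are hypothetical and only illustrate the call pattern used in __getitem__:

# Minimal sketch of the revised helper from data/csv.py (as changed above).
# The keyword lists passed in the usage lines are hypothetical examples.
import numpy as np

def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    # Drop each keyword with probability `dropout`, then optionally shuffle
    # the survivors before joining them into a comma-separated prompt string.
    if dropout != 0:
        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
    if shuffle:
        np.random.shuffle(prompt)
    return ", ".join(prompt)

# Training prompts are built with dropout and shuffling enabled, while class
# and negative prompts are joined deterministically, mirroring __getitem__:
print(keywords_to_prompt(["portrait", "studio lighting", "sharp focus"], dropout=0.1, shuffle=True))
print(keywords_to_prompt(["blurry", "lowres"]))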
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2f913e7..1a79b2b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -687,7 +687,7 @@ def main():
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
                 prompt = [item.cprompt for item in batch]
diff --git a/train_ti.py b/train_ti.py
index e272b5d..cc208f0 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -621,7 +621,7 @@ def main():
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
                 prompt = [item.cprompt for item in batch]
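
Both training scripts make the same one-line change: torch.autocast("cuda") is dropped and class-image generation runs under torch.inference_mode() alone. A small self-contained illustration of what that context manager provides by itself (the tensor values are arbitrary and only for demonstration):

# Small illustration of torch.inference_mode(), the context manager the
# class-image generation loop now runs under after autocast("cuda") was dropped.
import torch

x = torch.randn(2, 3, requires_grad=True)
with torch.inference_mode():
    y = x * 2                # no autograd graph is recorded here
print(y.requires_grad)       # False
print(y.is_inference())      # True: y was created under inference_mode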