From 3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 23 Dec 2022 23:02:01 +0100
Subject: Better dataset prompt handling

---
Review notes (ignored by `git am`; not part of the commit message):
- Fixed in this revision: the negative prompt was formatted with the
  positive-prompt template (`prompt.format(...)`) in the new
  `prompt_to_keywords(...)` call. The removed line used `nprompt.format(...)`
  and the sibling cprompt call uses `cprompt.format(...)`, so the nprompt
  call now uses `nprompt.format(...)` again.
- The post-image blob id on the data/csv.py `index` line predates that fix.
- NOTE(review): this patch was re-flowed from a whitespace-collapsed copy;
  indentation of context lines is reconstructed and should be confirmed
  against the tree (or applied with `git apply --ignore-whitespace`).
- NOTE(review): `keywords_to_prompt(..., shuffle=True)` with `dropout == 0`
  shuffles the caller's list in place, because no copy is made on that path.
  Confirm that this is acceptable for the stored dataset items.

 data/csv.py         | 25 ++++++++++++++++---------
 train_dreambooth.py |  2 +-
 train_ti.py         |  2 +-
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index edce2b1..265293b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,10 +15,11 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
     if dropout != 0:
         prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    np.random.shuffle(prompt)
+    if shuffle:
+        np.random.shuffle(prompt)
     return ", ".join(prompt)
 
 
@@ -38,8 +39,8 @@ class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
-    cprompt: str
-    nprompt: str
+    cprompt: list[str]
+    nprompt: list[str]
 
 
 class CSVDataModule():
@@ -104,8 +105,14 @@ class CSVDataModule():
                     prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                     expansions
                 ),
-                cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
-                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                prompt_to_keywords(
+                    cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
+                    expansions
+                ),
+                prompt_to_keywords(
+                    nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
+                    expansions
+                ),
             )
             for item in data
         ]
@@ -253,9 +260,9 @@ class CSVDataset(Dataset):
 
         example = {}
 
-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
+        example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"])
+        example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"])
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2f913e7..1a79b2b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -687,7 +687,7 @@ def main():
             ).to(accelerator.device)
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            with torch.autocast("cuda"), torch.inference_mode():
+            with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
                     prompt = [item.cprompt for item in batch]
diff --git a/train_ti.py b/train_ti.py
index e272b5d..cc208f0 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -621,7 +621,7 @@ def main():
             ).to(accelerator.device)
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            with torch.autocast("cuda"), torch.inference_mode():
+            with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
                     prompt = [item.cprompt for item in batch]
-- 
cgit v1.2.3-54-g00ecf