From 6b8a93f46f053668c8023520225a18445d48d8f1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 25 Mar 2023 16:34:48 +0100 Subject: Update --- data/csv.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index fba5d4b..a6cd065 100644 --- a/data/csv.py +++ b/data/csv.py @@ -99,14 +99,16 @@ def generate_buckets( return buckets, bucket_items, bucket_assignments -def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): +def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - if with_prior_preservation: + if with_guidance: + input_ids += [example["negative_prompt_ids"] for example in examples] + elif with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple): class_image_path: Path prompt: list[str] cprompt: str - nprompt: str + nprompt: list[str] collection: list[str] @@ -163,6 +165,7 @@ class VlpnDataModule(): data_file: str, tokenizer: CLIPTokenizer, class_subdir: str = "cls", + with_guidance: bool = False, num_class_images: int = 1, size: int = 768, num_buckets: int = 0, @@ -191,6 +194,7 @@ class VlpnDataModule(): self.class_root = self.data_root / class_subdir self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images + self.with_guidance = with_guidance self.tokenizer = tokenizer self.size = size @@ -228,10 +232,10 @@ class VlpnDataModule(): cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions )), - keywords_to_prompt(prompt_to_keywords( + prompt_to_keywords( nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions - )), + ), item["collection"].split(", ") if "collection" in item else [] ) for item in data @@ -279,7 +283,7 @@ class VlpnDataModule(): if self.seed is not None: generator = generator.manual_seed(self.seed) - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) if valid_set_size == 0: data_train, data_val = items, items @@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset): example = {} example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.get_input_ids(item.nprompt) + example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) example["instance_prompt_ids"] = self.get_input_ids( keywords_to_prompt(item.prompt, self.dropout, True) ) + example["negative_prompt_ids"] = self.get_input_ids( + keywords_to_prompt(item.nprompt, self.dropout, True) + ) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: -- cgit v1.2.3-54-g00ecf