| author | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
| commit | 6b8a93f46f053668c8023520225a18445d48d8f1 (patch) | |
| tree | 463c8835a9a90dd9b5586a13e55d6882caa3103a /data | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 21 |
1 file changed, 14 insertions(+), 7 deletions(-)
```diff
diff --git a/data/csv.py b/data/csv.py
index fba5d4b..a6cd065 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,14 +99,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_prior_preservation:
+    if with_guidance:
+        input_ids += [example["negative_prompt_ids"] for example in examples]
+    elif with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple):
     class_image_path: Path
     prompt: list[str]
     cprompt: str
-    nprompt: str
+    nprompt: list[str]
     collection: list[str]
 
 
@@ -163,6 +165,7 @@ class VlpnDataModule():
         data_file: str,
         tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
+        with_guidance: bool = False,
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
@@ -191,6 +194,7 @@ class VlpnDataModule():
         self.class_root = self.data_root / class_subdir
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
+        self.with_guidance = with_guidance
 
         self.tokenizer = tokenizer
         self.size = size
@@ -228,10 +232,10 @@ class VlpnDataModule():
                     cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                     expansions
                 )),
-                keywords_to_prompt(prompt_to_keywords(
+                prompt_to_keywords(
                     nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                     expansions
-                )),
+                ),
                 item["collection"].split(", ") if "collection" in item else []
             )
             for item in data
@@ -279,7 +283,7 @@ class VlpnDataModule():
         if self.seed is not None:
             generator = generator.manual_seed(self.seed)
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
@@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset):
         example = {}
 
         example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
-        example["nprompt_ids"] = self.get_input_ids(item.nprompt)
+        example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt))
 
         example["instance_prompt_ids"] = self.get_input_ids(
             keywords_to_prompt(item.prompt, self.dropout, True)
         )
+        example["negative_prompt_ids"] = self.get_input_ids(
+            keywords_to_prompt(item.nprompt, self.dropout, True)
+        )
         example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
         if self.num_class_images != 0:
```
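Taken together, the commit threads a new `with_guidance` switch from `VlpnDataModule` through `collate_fn`: when enabled, the tokenized negative prompts (`negative_prompt_ids`) are appended to `input_ids` instead of the class prompts used for prior preservation. A minimal sketch of how the flag might be passed in, assuming the remaining constructor arguments keep their defaults; the dataset path and tokenizer checkpoint below are illustrative, not taken from this commit:

```python
from transformers import CLIPTokenizer

from data.csv import VlpnDataModule

# Any CLIP tokenizer works here; the checkpoint name is illustrative.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = VlpnDataModule(
    data_file="data/template.json",  # hypothetical dataset file
    tokenizer=tokenizer,
    with_guidance=True,    # new in this commit, defaults to False
    num_class_images=0,    # collate_fn checks with_guidance before prior preservation
)
```

Note the ordering in `collate_fn`: instance prompts come first and the negative (or class) prompts are concatenated after them, so any consumer of the batch has to split `input_ids` along the batch dimension accordingly.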

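The switch from `nprompt: str` to `nprompt: list[str]` means negative prompts now stay as keyword lists until tokenization, so keyword dropout can apply to `negative_prompt_ids` exactly as it does to `instance_prompt_ids`. The real `prompt_to_keywords` and `keywords_to_prompt` are defined elsewhere in this repository; the sketch below only illustrates the round trip their call sites in this diff imply, and both signatures and behavior are assumptions:

```python
import random

def prompt_to_keywords(prompt: str, expansions: dict[str, str] | None = None) -> list[str]:
    # Assumed behavior, inferred from the call sites above: split a
    # comma-separated prompt into trimmed keywords and apply expansions.
    expansions = expansions or {}
    keywords = [k.strip() for k in prompt.split(",") if k.strip()]
    return [expansions.get(k, k) for k in keywords]

def keywords_to_prompt(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    # Assumed behavior, inferred from keywords_to_prompt(item.prompt, self.dropout, True):
    # optionally drop and shuffle keywords, then join them back into one prompt string.
    if dropout != 0:
        keywords = [k for k in keywords if random.random() >= dropout]
    if shuffle:
        keywords = keywords.copy()
        random.shuffle(keywords)
    return ", ".join(keywords)
```

Under this reading, `example["nprompt_ids"]` tokenizes the full negative prompt, while `negative_prompt_ids` gets the dropout-and-shuffle variant used for training, mirroring how `prompt_ids` and `instance_prompt_ids` already relate.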