Diffstat (limited to 'data')
-rw-r--r--  data/csv.py | 21 ++++++++++++++-------
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index fba5d4b..a6cd065 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,14 +99,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_prior_preservation:
+    if with_guidance:
+        input_ids += [example["negative_prompt_ids"] for example in examples]
+    elif with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
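
[Editor's note] A minimal sketch of the new branch order in collate_fn above:
with_guidance takes precedence over prior preservation, so the batch is extended
with negative prompts rather than class prompts when both flags are set. The
gather_input_ids helper and the example dicts are illustrative stand-ins, not
code from this repo.

    def gather_input_ids(examples, with_guidance, with_prior_preservation):
        # Mirrors the updated collate_fn branch: guidance wins when both are set.
        input_ids = [e["instance_prompt_ids"] for e in examples]
        if with_guidance:
            input_ids += [e["negative_prompt_ids"] for e in examples]
        elif with_prior_preservation:
            input_ids += [e["class_prompt_ids"] for e in examples]
        return input_ids

    examples = [{"instance_prompt_ids": [1], "negative_prompt_ids": [2], "class_prompt_ids": [3]}]
    assert gather_input_ids(examples, True, True) == [[1], [2]]   # guidance branch
    assert gather_input_ids(examples, False, True) == [[1], [3]]  # prior-preservation branch
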
@@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple):
     class_image_path: Path
     prompt: list[str]
     cprompt: str
-    nprompt: str
+    nprompt: list[str]
     collection: list[str]
 
 
@@ -163,6 +165,7 @@ class VlpnDataModule():
         data_file: str,
         tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
+        with_guidance: bool = False,
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
@@ -191,6 +194,7 @@ class VlpnDataModule():
         self.class_root = self.data_root / class_subdir
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
+        self.with_guidance = with_guidance
 
         self.tokenizer = tokenizer
         self.size = size
@@ -228,10 +232,10 @@ class VlpnDataModule():
                 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                 expansions
             )),
-            keywords_to_prompt(prompt_to_keywords(
+            prompt_to_keywords(
                 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                 expansions
-            )),
+            ),
             item["collection"].split(", ") if "collection" in item else []
         )
         for item in data
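
[Editor's note] The change above stops collapsing the negative prompt back into a
single string when items are loaded: nprompt now stays a keyword list (matching
the nprompt: list[str] type change in VlpnDataItem), so keyword dropout can be
applied per example at encode time, as the final hunk below shows. A stub of that
idea; the keywords_to_prompt signature is inferred from the calls in this diff
and is an assumption, not the repo's implementation.

    import random

    def keywords_to_prompt(keywords, dropout=0.0, shuffle=False):
        # Illustrative stub only. Drops each keyword with probability `dropout`,
        # optionally shuffles, then joins the survivors into a prompt string.
        kept = [k for k in keywords if random.random() >= dropout]
        if shuffle:
            random.shuffle(kept)
        return ", ".join(kept)

    print(keywords_to_prompt(["blurry", "low quality", "watermark"], dropout=0.1, shuffle=True))
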
@@ -279,7 +283,7 @@ class VlpnDataModule():
         if self.seed is not None:
             generator = generator.manual_seed(self.seed)
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
@@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset):
         example = {}
 
         example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
-        example["nprompt_ids"] = self.get_input_ids(item.nprompt)
+        example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt))
 
         example["instance_prompt_ids"] = self.get_input_ids(
             keywords_to_prompt(item.prompt, self.dropout, True)
         )
+        example["negative_prompt_ids"] = self.get_input_ids(
+            keywords_to_prompt(item.nprompt, self.dropout, True)
+        )
         example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
         if self.num_class_images != 0:
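
[Editor's note] In the collate_fn_ = partial(...) hunk above, argument order
matters: the new with_guidance flag is bound before with_prior_preservation,
matching the updated collate_fn signature. A toy stand-in demonstrating the
binding; the string placeholders replace the real dtype and tokenizer objects.

    from functools import partial

    def collate_fn(dtype, tokenizer, with_guidance, with_prior_preservation, examples):
        # Toy body: just expose the bound flags and the batch size.
        return (with_guidance, with_prior_preservation, len(examples))

    collate_fn_ = partial(collate_fn, "float32", "clip-tokenizer", True, False)
    assert collate_fn_([{"prompt_ids": [1]}]) == (True, False, 1)
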