summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py
index fba5d4b..a6cd065 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,14 +99,16 @@ def generate_buckets(
99 return buckets, bucket_items, bucket_assignments 99 return buckets, bucket_items, bucket_assignments
100 100
101 101
102def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): 102def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
103 prompt_ids = [example["prompt_ids"] for example in examples] 103 prompt_ids = [example["prompt_ids"] for example in examples]
104 nprompt_ids = [example["nprompt_ids"] for example in examples] 104 nprompt_ids = [example["nprompt_ids"] for example in examples]
105 105
106 input_ids = [example["instance_prompt_ids"] for example in examples] 106 input_ids = [example["instance_prompt_ids"] for example in examples]
107 pixel_values = [example["instance_images"] for example in examples] 107 pixel_values = [example["instance_images"] for example in examples]
108 108
109 if with_prior_preservation: 109 if with_guidance:
110 input_ids += [example["negative_prompt_ids"] for example in examples]
111 elif with_prior_preservation:
110 input_ids += [example["class_prompt_ids"] for example in examples] 112 input_ids += [example["class_prompt_ids"] for example in examples]
111 pixel_values += [example["class_images"] for example in examples] 113 pixel_values += [example["class_images"] for example in examples]
112 114
@@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple):
133 class_image_path: Path 135 class_image_path: Path
134 prompt: list[str] 136 prompt: list[str]
135 cprompt: str 137 cprompt: str
136 nprompt: str 138 nprompt: list[str]
137 collection: list[str] 139 collection: list[str]
138 140
139 141
@@ -163,6 +165,7 @@ class VlpnDataModule():
163 data_file: str, 165 data_file: str,
164 tokenizer: CLIPTokenizer, 166 tokenizer: CLIPTokenizer,
165 class_subdir: str = "cls", 167 class_subdir: str = "cls",
168 with_guidance: bool = False,
166 num_class_images: int = 1, 169 num_class_images: int = 1,
167 size: int = 768, 170 size: int = 768,
168 num_buckets: int = 0, 171 num_buckets: int = 0,
@@ -191,6 +194,7 @@ class VlpnDataModule():
191 self.class_root = self.data_root / class_subdir 194 self.class_root = self.data_root / class_subdir
192 self.class_root.mkdir(parents=True, exist_ok=True) 195 self.class_root.mkdir(parents=True, exist_ok=True)
193 self.num_class_images = num_class_images 196 self.num_class_images = num_class_images
197 self.with_guidance = with_guidance
194 198
195 self.tokenizer = tokenizer 199 self.tokenizer = tokenizer
196 self.size = size 200 self.size = size
@@ -228,10 +232,10 @@ class VlpnDataModule():
228 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 232 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
229 expansions 233 expansions
230 )), 234 )),
231 keywords_to_prompt(prompt_to_keywords( 235 prompt_to_keywords(
232 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 236 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
233 expansions 237 expansions
234 )), 238 ),
235 item["collection"].split(", ") if "collection" in item else [] 239 item["collection"].split(", ") if "collection" in item else []
236 ) 240 )
237 for item in data 241 for item in data
@@ -279,7 +283,7 @@ class VlpnDataModule():
279 if self.seed is not None: 283 if self.seed is not None:
280 generator = generator.manual_seed(self.seed) 284 generator = generator.manual_seed(self.seed)
281 285
282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 286 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
283 287
284 if valid_set_size == 0: 288 if valid_set_size == 0:
285 data_train, data_val = items, items 289 data_train, data_val = items, items
@@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset):
443 example = {} 447 example = {}
444 448
445 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) 449 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
446 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 450 example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt))
447 451
448 example["instance_prompt_ids"] = self.get_input_ids( 452 example["instance_prompt_ids"] = self.get_input_ids(
449 keywords_to_prompt(item.prompt, self.dropout, True) 453 keywords_to_prompt(item.prompt, self.dropout, True)
450 ) 454 )
455 example["negative_prompt_ids"] = self.get_input_ids(
456 keywords_to_prompt(item.nprompt, self.dropout, True)
457 )
451 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 458 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
452 459
453 if self.num_class_images != 0: 460 if self.num_class_images != 0: