diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 47 |
1 files changed, 42 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py index 9ad7dd6..f5fc8e6 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,7 +1,7 @@ | |||
1 | import math | 1 | import math |
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | import copy | 4 | from functools import partial |
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from typing import NamedTuple, Optional, Union, Callable | 6 | from typing import NamedTuple, Optional, Union, Callable |
7 | 7 | ||
@@ -99,6 +99,41 @@ def generate_buckets( | |||
99 | return buckets, bucket_items, bucket_assignments | 99 | return buckets, bucket_items, bucket_assignments |
100 | 100 | ||
101 | 101 | ||
102 | def collate_fn( | ||
103 | num_class_images: int, | ||
104 | weight_dtype: torch.dtype, | ||
105 | prompt_processor: PromptProcessor, | ||
106 | examples | ||
107 | ): | ||
108 | prompt_ids = [example["prompt_ids"] for example in examples] | ||
109 | nprompt_ids = [example["nprompt_ids"] for example in examples] | ||
110 | |||
111 | input_ids = [example["instance_prompt_ids"] for example in examples] | ||
112 | pixel_values = [example["instance_images"] for example in examples] | ||
113 | |||
114 | # concat class and instance examples for prior preservation | ||
115 | if num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
116 | input_ids += [example["class_prompt_ids"] for example in examples] | ||
117 | pixel_values += [example["class_images"] for example in examples] | ||
118 | |||
119 | pixel_values = torch.stack(pixel_values) | ||
120 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | ||
121 | |||
122 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
123 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
124 | inputs = prompt_processor.unify_input_ids(input_ids) | ||
125 | |||
126 | batch = { | ||
127 | "prompt_ids": prompts.input_ids, | ||
128 | "nprompt_ids": nprompts.input_ids, | ||
129 | "input_ids": inputs.input_ids, | ||
130 | "pixel_values": pixel_values, | ||
131 | "attention_mask": inputs.attention_mask, | ||
132 | } | ||
133 | |||
134 | return batch | ||
135 | |||
136 | |||
102 | class VlpnDataItem(NamedTuple): | 137 | class VlpnDataItem(NamedTuple): |
103 | instance_image_path: Path | 138 | instance_image_path: Path |
104 | class_image_path: Path | 139 | class_image_path: Path |
@@ -129,7 +164,7 @@ class VlpnDataModule(): | |||
129 | valid_set_repeat: int = 1, | 164 | valid_set_repeat: int = 1, |
130 | seed: Optional[int] = None, | 165 | seed: Optional[int] = None, |
131 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 166 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
132 | collate_fn=None, | 167 | dtype: torch.dtype = torch.float32, |
133 | num_workers: int = 0 | 168 | num_workers: int = 0 |
134 | ): | 169 | ): |
135 | super().__init__() | 170 | super().__init__() |
@@ -158,9 +193,9 @@ class VlpnDataModule(): | |||
158 | self.valid_set_repeat = valid_set_repeat | 193 | self.valid_set_repeat = valid_set_repeat |
159 | self.seed = seed | 194 | self.seed = seed |
160 | self.filter = filter | 195 | self.filter = filter |
161 | self.collate_fn = collate_fn | ||
162 | self.num_workers = num_workers | 196 | self.num_workers = num_workers |
163 | self.batch_size = batch_size | 197 | self.batch_size = batch_size |
198 | self.dtype = dtype | ||
164 | 199 | ||
165 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 200 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
166 | image = template["image"] if "image" in template else "{}" | 201 | image = template["image"] if "image" in template else "{}" |
@@ -254,14 +289,16 @@ class VlpnDataModule(): | |||
254 | size=self.size, interpolation=self.interpolation, | 289 | size=self.size, interpolation=self.interpolation, |
255 | ) | 290 | ) |
256 | 291 | ||
292 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) | ||
293 | |||
257 | self.train_dataloader = DataLoader( | 294 | self.train_dataloader = DataLoader( |
258 | train_dataset, | 295 | train_dataset, |
259 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 296 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers |
260 | ) | 297 | ) |
261 | 298 | ||
262 | self.val_dataloader = DataLoader( | 299 | self.val_dataloader = DataLoader( |
263 | val_dataset, | 300 | val_dataset, |
264 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 301 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers |
265 | ) | 302 | ) |
266 | 303 | ||
267 | 304 | ||