diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 47 |
1 files changed, 42 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py index 9ad7dd6..f5fc8e6 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,7 +1,7 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import torch | 2 | import torch |
| 3 | import json | 3 | import json |
| 4 | import copy | 4 | from functools import partial |
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from typing import NamedTuple, Optional, Union, Callable | 6 | from typing import NamedTuple, Optional, Union, Callable |
| 7 | 7 | ||
| @@ -99,6 +99,41 @@ def generate_buckets( | |||
| 99 | return buckets, bucket_items, bucket_assignments | 99 | return buckets, bucket_items, bucket_assignments |
| 100 | 100 | ||
| 101 | 101 | ||
| 102 | def collate_fn( | ||
| 103 | num_class_images: int, | ||
| 104 | weight_dtype: torch.dtype, | ||
| 105 | prompt_processor: PromptProcessor, | ||
| 106 | examples | ||
| 107 | ): | ||
| 108 | prompt_ids = [example["prompt_ids"] for example in examples] | ||
| 109 | nprompt_ids = [example["nprompt_ids"] for example in examples] | ||
| 110 | |||
| 111 | input_ids = [example["instance_prompt_ids"] for example in examples] | ||
| 112 | pixel_values = [example["instance_images"] for example in examples] | ||
| 113 | |||
| 114 | # concat class and instance examples for prior preservation | ||
| 115 | if num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
| 116 | input_ids += [example["class_prompt_ids"] for example in examples] | ||
| 117 | pixel_values += [example["class_images"] for example in examples] | ||
| 118 | |||
| 119 | pixel_values = torch.stack(pixel_values) | ||
| 120 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | ||
| 121 | |||
| 122 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
| 123 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
| 124 | inputs = prompt_processor.unify_input_ids(input_ids) | ||
| 125 | |||
| 126 | batch = { | ||
| 127 | "prompt_ids": prompts.input_ids, | ||
| 128 | "nprompt_ids": nprompts.input_ids, | ||
| 129 | "input_ids": inputs.input_ids, | ||
| 130 | "pixel_values": pixel_values, | ||
| 131 | "attention_mask": inputs.attention_mask, | ||
| 132 | } | ||
| 133 | |||
| 134 | return batch | ||
| 135 | |||
| 136 | |||
| 102 | class VlpnDataItem(NamedTuple): | 137 | class VlpnDataItem(NamedTuple): |
| 103 | instance_image_path: Path | 138 | instance_image_path: Path |
| 104 | class_image_path: Path | 139 | class_image_path: Path |
| @@ -129,7 +164,7 @@ class VlpnDataModule(): | |||
| 129 | valid_set_repeat: int = 1, | 164 | valid_set_repeat: int = 1, |
| 130 | seed: Optional[int] = None, | 165 | seed: Optional[int] = None, |
| 131 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 166 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
| 132 | collate_fn=None, | 167 | dtype: torch.dtype = torch.float32, |
| 133 | num_workers: int = 0 | 168 | num_workers: int = 0 |
| 134 | ): | 169 | ): |
| 135 | super().__init__() | 170 | super().__init__() |
| @@ -158,9 +193,9 @@ class VlpnDataModule(): | |||
| 158 | self.valid_set_repeat = valid_set_repeat | 193 | self.valid_set_repeat = valid_set_repeat |
| 159 | self.seed = seed | 194 | self.seed = seed |
| 160 | self.filter = filter | 195 | self.filter = filter |
| 161 | self.collate_fn = collate_fn | ||
| 162 | self.num_workers = num_workers | 196 | self.num_workers = num_workers |
| 163 | self.batch_size = batch_size | 197 | self.batch_size = batch_size |
| 198 | self.dtype = dtype | ||
| 164 | 199 | ||
| 165 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 200 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
| 166 | image = template["image"] if "image" in template else "{}" | 201 | image = template["image"] if "image" in template else "{}" |
| @@ -254,14 +289,16 @@ class VlpnDataModule(): | |||
| 254 | size=self.size, interpolation=self.interpolation, | 289 | size=self.size, interpolation=self.interpolation, |
| 255 | ) | 290 | ) |
| 256 | 291 | ||
| 292 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) | ||
| 293 | |||
| 257 | self.train_dataloader = DataLoader( | 294 | self.train_dataloader = DataLoader( |
| 258 | train_dataset, | 295 | train_dataset, |
| 259 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 296 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers |
| 260 | ) | 297 | ) |
| 261 | 298 | ||
| 262 | self.val_dataloader = DataLoader( | 299 | self.val_dataloader = DataLoader( |
| 263 | val_dataset, | 300 | val_dataset, |
| 264 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 301 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers |
| 265 | ) | 302 | ) |
| 266 | 303 | ||
| 267 | 304 | ||
