summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py47
1 files changed, 42 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py
index 9ad7dd6..f5fc8e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,7 @@
1import math 1import math
2import torch 2import torch
3import json 3import json
4import copy 4from functools import partial
5from pathlib import Path 5from pathlib import Path
6from typing import NamedTuple, Optional, Union, Callable 6from typing import NamedTuple, Optional, Union, Callable
7 7
@@ -99,6 +99,41 @@ def generate_buckets(
99 return buckets, bucket_items, bucket_assignments 99 return buckets, bucket_items, bucket_assignments
100 100
101 101
102def collate_fn(
103 num_class_images: int,
104 weight_dtype: torch.dtype,
105 prompt_processor: PromptProcessor,
106 examples
107):
108 prompt_ids = [example["prompt_ids"] for example in examples]
109 nprompt_ids = [example["nprompt_ids"] for example in examples]
110
111 input_ids = [example["instance_prompt_ids"] for example in examples]
112 pixel_values = [example["instance_images"] for example in examples]
113
114 # concat class and instance examples for prior preservation
115 if num_class_images != 0 and "class_prompt_ids" in examples[0]:
116 input_ids += [example["class_prompt_ids"] for example in examples]
117 pixel_values += [example["class_images"] for example in examples]
118
119 pixel_values = torch.stack(pixel_values)
120 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
121
122 prompts = prompt_processor.unify_input_ids(prompt_ids)
123 nprompts = prompt_processor.unify_input_ids(nprompt_ids)
124 inputs = prompt_processor.unify_input_ids(input_ids)
125
126 batch = {
127 "prompt_ids": prompts.input_ids,
128 "nprompt_ids": nprompts.input_ids,
129 "input_ids": inputs.input_ids,
130 "pixel_values": pixel_values,
131 "attention_mask": inputs.attention_mask,
132 }
133
134 return batch
135
136
102class VlpnDataItem(NamedTuple): 137class VlpnDataItem(NamedTuple):
103 instance_image_path: Path 138 instance_image_path: Path
104 class_image_path: Path 139 class_image_path: Path
@@ -129,7 +164,7 @@ class VlpnDataModule():
129 valid_set_repeat: int = 1, 164 valid_set_repeat: int = 1,
130 seed: Optional[int] = None, 165 seed: Optional[int] = None,
131 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 166 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
132 collate_fn=None, 167 dtype: torch.dtype = torch.float32,
133 num_workers: int = 0 168 num_workers: int = 0
134 ): 169 ):
135 super().__init__() 170 super().__init__()
@@ -158,9 +193,9 @@ class VlpnDataModule():
158 self.valid_set_repeat = valid_set_repeat 193 self.valid_set_repeat = valid_set_repeat
159 self.seed = seed 194 self.seed = seed
160 self.filter = filter 195 self.filter = filter
161 self.collate_fn = collate_fn
162 self.num_workers = num_workers 196 self.num_workers = num_workers
163 self.batch_size = batch_size 197 self.batch_size = batch_size
198 self.dtype = dtype
164 199
165 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 200 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
166 image = template["image"] if "image" in template else "{}" 201 image = template["image"] if "image" in template else "{}"
@@ -254,14 +289,16 @@ class VlpnDataModule():
254 size=self.size, interpolation=self.interpolation, 289 size=self.size, interpolation=self.interpolation,
255 ) 290 )
256 291
292 collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
293
257 self.train_dataloader = DataLoader( 294 self.train_dataloader = DataLoader(
258 train_dataset, 295 train_dataset,
259 batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers 296 batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
260 ) 297 )
261 298
262 self.val_dataloader = DataLoader( 299 self.val_dataloader = DataLoader(
263 val_dataset, 300 val_dataset,
264 batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers 301 batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
265 ) 302 )
266 303
267 304