Diffstat (limited to 'data')
 data/csv.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image
 
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
+from transformers import CLIPTokenizer
 
 from data.keywords import prompt_to_keywords, keywords_to_prompt
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids
 
 
 image_cache: dict[str, Image.Image] = {}
@@ -102,7 +103,7 @@ def generate_buckets(
 def collate_fn(
     num_class_images: int,
     weight_dtype: torch.dtype,
-    prompt_processor: PromptProcessor,
+    tokenizer: CLIPTokenizer,
     examples
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
@@ -119,9 +120,9 @@ def collate_fn(
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-    prompts = prompt_processor.unify_input_ids(prompt_ids)
-    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-    inputs = prompt_processor.unify_input_ids(input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids)
+    inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -148,7 +149,7 @@ class VlpnDataModule():
         self,
         batch_size: int,
         data_file: str,
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
@@ -179,7 +180,7 @@ class VlpnDataModule():
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
         self.bucket_step_size = bucket_step_size
@@ -272,7 +273,7 @@ class VlpnDataModule():
         self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
-            self.data_train, self.prompt_processor,
+            self.data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
@@ -281,7 +282,7 @@ class VlpnDataModule():
         )
 
         val_dataset = VlpnDataset(
-            self.data_val, self.prompt_processor,
+            self.data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             repeat=self.valid_set_repeat,
@@ -289,7 +290,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset):
     def __init__(
         self,
         items: list[VlpnDataItem],
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         num_buckets: int = 1,
         bucket_step_size: int = 64,
         bucket_max_pixels: Optional[int] = None,
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset):
         self.items = items * repeat
         self.batch_size = batch_size
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.num_class_images = num_class_images
         self.size = size
         self.dropout = dropout
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset):
 
         self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
+    def get_input_ids(self, text: str):
+        return self.tokenizer(text, padding="do_not_pad").input_ids
+
     def __len__(self):
         return self.length_
 
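
The new get_input_ids helper tokenizes without padding, so each example carries a raw list of token ids; padding to a common length is deferred to collate_fn via unify_input_ids. A hedged usage sketch (the checkpoint name and prompt are only illustrative):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative checkpoint
    ids = tokenizer("a photo of a cat", padding="do_not_pad").input_ids
    # ids is a plain Python list of token ids, however long the prompt happens to be;
    # batching/padding happens later in collate_fn via unify_input_ids(tokenizer, ...).
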
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset):
 
             example = {}
 
-            example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
-            example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
+            example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
+            example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+            example["instance_prompt_ids"] = self.get_input_ids(
                 keywords_to_prompt(item.prompt, self.dropout, True)
             )
             example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
             if self.num_class_images != 0:
-                example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
+                example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
                 example["class_images"] = image_transforms(get_image(item.class_image_path))
 
             batch.append(example)
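
With PromptProcessor gone, callers pass the CLIPTokenizer straight through to the dataset and the collate function. A minimal wiring sketch, assuming the same illustrative checkpoint as above and placeholder values for num_class_images and the weight dtype:

    from functools import partial

    import torch
    from transformers import CLIPTokenizer

    from data.csv import collate_fn

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative checkpoint

    # collate_fn now receives the tokenizer itself instead of a PromptProcessor wrapper;
    # num_class_images=1 and float32 are placeholder values for this sketch.
    collate = partial(collate_fn, 1, torch.float32, tokenizer)
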