Diffstat (limited to 'data')
 data/csv.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image
 
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
+from transformers import CLIPTokenizer
 
 from data.keywords import prompt_to_keywords, keywords_to_prompt
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids
 
 
 image_cache: dict[str, Image.Image] = {}
@@ -102,7 +103,7 @@ def generate_buckets(
 def collate_fn(
     num_class_images: int,
     weight_dtype: torch.dtype,
-    prompt_processor: PromptProcessor,
+    tokenizer: CLIPTokenizer,
     examples
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
@@ -119,9 +120,9 @@ def collate_fn(
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-    prompts = prompt_processor.unify_input_ids(prompt_ids)
-    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-    inputs = prompt_processor.unify_input_ids(input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids)
+    inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
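The unify_input_ids helper imported from models/clip/util.py is not part of this diff. As a rough sketch of what it presumably does, based on how collate_fn consumes its result (per-example ids are tokenized with padding="do_not_pad" in get_input_ids below, and the returned object exposes .input_ids), it likely pads each batch to a common length via the tokenizer:

# Assumed sketch of models/clip/util.unify_input_ids -- not shown in this diff.
from transformers import BatchEncoding, CLIPTokenizer


def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]) -> BatchEncoding:
    # Pad the variable-length token id lists of a batch to a common length and
    # return PyTorch tensors (input_ids plus attention_mask) for the text encoder.
    return tokenizer.pad(
        {"input_ids": input_ids},
        padding=True,  # pad to the longest sequence in this batch
        return_tensors="pt",
    )

Under this reading, per-prompt tokenization stays unpadded in the dataset and padding is deferred to collate time, which is consistent with the padding="do_not_pad" call added to VlpnDataset further down.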
@@ -148,7 +149,7 @@ class VlpnDataModule():
         self,
         batch_size: int,
         data_file: str,
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
@@ -179,7 +180,7 @@ class VlpnDataModule():
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
         self.bucket_step_size = bucket_step_size
@@ -272,7 +273,7 @@ class VlpnDataModule():
         self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
-            self.data_train, self.prompt_processor,
+            self.data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
@@ -281,7 +282,7 @@ class VlpnDataModule():
         )
 
         val_dataset = VlpnDataset(
-            self.data_val, self.prompt_processor,
+            self.data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             repeat=self.valid_set_repeat,
@@ -289,7 +290,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset):
     def __init__(
         self,
         items: list[VlpnDataItem],
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
        num_buckets: int = 1,
         bucket_step_size: int = 64,
         bucket_max_pixels: Optional[int] = None,
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset):
         self.items = items * repeat
         self.batch_size = batch_size
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.num_class_images = num_class_images
         self.size = size
         self.dropout = dropout
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset):
 
         self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
+    def get_input_ids(self, text: str):
+        return self.tokenizer(text, padding="do_not_pad").input_ids
+
     def __len__(self):
         return self.length_
 
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset):
 
            example = {}
 
-            example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
-            example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
+            example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
+            example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+            example["instance_prompt_ids"] = self.get_input_ids(
                 keywords_to_prompt(item.prompt, self.dropout, True)
             )
             example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
             if self.num_class_images != 0:
-                example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
+                example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
                 example["class_images"] = image_transforms(get_image(item.class_image_path))
 
             batch.append(example)
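On the caller side, this change means the data module is now constructed with the CLIPTokenizer itself rather than a PromptProcessor wrapper. A hypothetical sketch of that call site follows; the checkpoint name, data file path, and batch size are placeholders and not taken from this commit, and any keyword arguments not shown in the diff are omitted.

from transformers import CLIPTokenizer

from data.csv import VlpnDataModule

# Placeholder checkpoint; any pipeline that ships a CLIP tokenizer subfolder works the same way.
tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)

datamodule = VlpnDataModule(
    batch_size=1,
    data_file="path/to/data.json",  # placeholder
    tokenizer=tokenizer,            # previously: prompt_processor=PromptProcessor(...)
)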