Diffstat (limited to 'data')
 -rw-r--r--   data/csv.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image
 
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
+from transformers import CLIPTokenizer
 
 from data.keywords import prompt_to_keywords, keywords_to_prompt
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids
 
 
 image_cache: dict[str, Image.Image] = {}
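The PromptProcessor wrapper is dropped: the module now takes a CLIPTokenizer directly and batches token ids through a standalone unify_input_ids helper in models/clip/util. That module is not part of this diff; judging only from its call sites in collate_fn below, a minimal sketch might look like the following (the body, in particular the use of tokenizer.pad, is an assumption):

    # Hypothetical sketch of models/clip/util.py, inferred from the call sites
    # in collate_fn; the real implementation is not shown in this diff.
    from transformers import BatchEncoding, CLIPTokenizer

    def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]) -> BatchEncoding:
        # Pad the ragged per-example id lists to a common length so the result
        # exposes a rectangular .input_ids tensor (plus .attention_mask).
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        )
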
@@ -102,7 +103,7 @@ def generate_buckets(
 def collate_fn(
     num_class_images: int,
     weight_dtype: torch.dtype,
-    prompt_processor: PromptProcessor,
+    tokenizer: CLIPTokenizer,
     examples
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
@@ -119,9 +120,9 @@ def collate_fn(
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-    prompts = prompt_processor.unify_input_ids(prompt_ids)
-    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-    inputs = prompt_processor.unify_input_ids(input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids)
+    inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -148,7 +149,7 @@ class VlpnDataModule():
         self,
         batch_size: int,
         data_file: str,
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
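Call sites that construct VlpnDataModule now pass the tokenizer itself rather than a PromptProcessor. A hedged usage sketch (checkpoint path and CSV file are placeholders; the remaining keyword arguments keep their defaults):

    from transformers import CLIPTokenizer
    from data.csv import VlpnDataModule

    # Placeholder paths, for illustration only.
    tokenizer = CLIPTokenizer.from_pretrained("path/to/model", subfolder="tokenizer")
    datamodule = VlpnDataModule(
        batch_size=4,
        data_file="path/to/train.csv",
        tokenizer=tokenizer,  # previously: prompt_processor=PromptProcessor(...)
    )
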
@@ -179,7 +180,7 @@ class VlpnDataModule():
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
         self.bucket_step_size = bucket_step_size
@@ -272,7 +273,7 @@ class VlpnDataModule():
         self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
-            self.data_train, self.prompt_processor,
+            self.data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
@@ -281,7 +282,7 @@ class VlpnDataModule():
         )
 
         val_dataset = VlpnDataset(
-            self.data_val, self.prompt_processor,
+            self.data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             repeat=self.valid_set_repeat,
@@ -289,7 +290,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
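As before, functools.partial binds collate_fn's leading positional parameters, so the DataLoader only ever supplies the list of examples; the tokenizer simply takes the slot the processor used to occupy. A toy illustration of the pattern (stand-in function and values, not the real collate_fn):

    from functools import partial

    def collate(num_class_images, weight_dtype, tokenizer, examples):
        # Stand-in for data.csv.collate_fn; only the call shape matters here.
        return (num_class_images, weight_dtype, tokenizer, len(examples))

    collate_ = partial(collate, 0, "float32", "tokenizer-goes-here")
    collate_(["ex1", "ex2"])  # same as collate(0, "float32", "tokenizer-goes-here", ["ex1", "ex2"])
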
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset):
     def __init__(
         self,
         items: list[VlpnDataItem],
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         num_buckets: int = 1,
         bucket_step_size: int = 64,
         bucket_max_pixels: Optional[int] = None,
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset):
         self.items = items * repeat
         self.batch_size = batch_size
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.num_class_images = num_class_images
         self.size = size
         self.dropout = dropout
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset):
 
         self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
+    def get_input_ids(self, text: str):
+        return self.tokenizer(text, padding="do_not_pad").input_ids
+
     def __len__(self):
         return self.length_
 
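The tokenization that PromptProcessor.get_input_ids used to provide now lives on the dataset itself. With padding="do_not_pad", each prompt yields an unpadded, variable-length list of token ids; padding is deferred to collate_fn, where unify_input_ids pads a whole batch at once. An illustrative snippet, assuming dataset is an already constructed VlpnDataset:

    # Lengths differ per prompt because nothing is padded at this stage.
    short_ids = dataset.get_input_ids("a photo of a dog")
    long_ids = dataset.get_input_ids("a detailed photo of a dog on a beach at sunset")
    assert len(short_ids) < len(long_ids)  # equal lengths only appear after collate_fn
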
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset):
 
             example = {}
 
-            example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
-            example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
+            example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
+            example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+            example["instance_prompt_ids"] = self.get_input_ids(
                 keywords_to_prompt(item.prompt, self.dropout, True)
             )
             example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
             if self.num_class_images != 0:
-                example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
+                example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
                 example["class_images"] = image_transforms(get_image(item.class_image_path))
 
             batch.append(example)
