From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- data/csv.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index f5fc8e6..a3fef30 100644 --- a/data/csv.py +++ b/data/csv.py @@ -9,9 +9,10 @@ from PIL import Image from torch.utils.data import IterableDataset, DataLoader, random_split from torchvision import transforms +from transformers import CLIPTokenizer from data.keywords import prompt_to_keywords, keywords_to_prompt -from models.clip.prompt import PromptProcessor +from models.clip.util import unify_input_ids image_cache: dict[str, Image.Image] = {} @@ -102,7 +103,7 @@ def generate_buckets( def collate_fn( num_class_images: int, weight_dtype: torch.dtype, - prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, examples ): prompt_ids = [example["prompt_ids"] for example in examples] @@ -119,9 +120,9 @@ def collate_fn( pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - prompts = prompt_processor.unify_input_ids(prompt_ids) - nprompts = prompt_processor.unify_input_ids(nprompt_ids) - inputs = prompt_processor.unify_input_ids(input_ids) + prompts = unify_input_ids(tokenizer, prompt_ids) + nprompts = unify_input_ids(tokenizer, nprompt_ids) + inputs = unify_input_ids(tokenizer, input_ids) batch = { "prompt_ids": prompts.input_ids, @@ -148,7 +149,7 @@ class VlpnDataModule(): self, batch_size: int, data_file: str, - prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, class_subdir: str = "cls", num_class_images: int = 1, size: int = 768, @@ -179,7 +180,7 @@ class VlpnDataModule(): self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images - self.prompt_processor = prompt_processor + self.tokenizer = tokenizer self.size = size self.num_buckets = num_buckets self.bucket_step_size = bucket_step_size @@ -272,7 +273,7 @@ class VlpnDataModule(): self.data_val = self.pad_items(data_val) train_dataset = VlpnDataset( - self.data_train, self.prompt_processor, + self.data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, @@ -281,7 +282,7 @@ class VlpnDataModule(): ) val_dataset = VlpnDataset( - self.data_val, self.prompt_processor, + self.data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, repeat=self.valid_set_repeat, @@ -289,7 +290,7 @@ class VlpnDataModule(): size=self.size, interpolation=self.interpolation, ) - collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) + collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) self.train_dataloader = DataLoader( train_dataset, @@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset): def __init__( self, items: list[VlpnDataItem], - prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, num_buckets: int = 1, bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, @@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset): self.items = items * repeat self.batch_size = batch_size - self.prompt_processor = prompt_processor + self.tokenizer = tokenizer self.num_class_images = num_class_images self.size = size self.dropout = dropout @@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset): self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() + def get_input_ids(self, text: str): + return self.tokenizer(text, padding="do_not_pad").input_ids + def __len__(self): return self.length_ @@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset): example = {} - example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) + example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) + example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + example["instance_prompt_ids"] = self.get_input_ids( keywords_to_prompt(item.prompt, self.dropout, True) ) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: - example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) + example["class_prompt_ids"] = self.get_input_ids(item.cprompt) example["class_images"] = image_transforms(get_image(item.class_image_path)) batch.append(example) -- cgit v1.2.3-54-g00ecf