From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 13 Jan 2023 13:49:35 +0100
Subject: Removed PromptProcessor, modularized training loop
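
The data pipeline no longer depends on PromptProcessor: tokenization is
done directly with the CLIPTokenizer, and batching of the variable-length
id lists is delegated to models.clip.util.unify_input_ids. That helper
lives outside this diff (which is limited to data/); a minimal sketch of
what it is assumed to do, namely pad the per-example id lists into one
batch via the tokenizer:

    from transformers import CLIPTokenizer

    def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
        # Pad every id list to the longest one in the batch and return a
        # BatchEncoding exposing .input_ids and .attention_mask tensors,
        # matching the prompts.input_ids accesses in collate_fn below.
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        )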

---
 data/csv.py | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)

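Call sites that previously constructed VlpnDataModule with a
PromptProcessor now pass the CLIPTokenizer directly. A hypothetical
updated call site; the pretrained id, data file path and remaining
keyword arguments are purely illustrative:

    from transformers import CLIPTokenizer
    from data.csv import VlpnDataModule

    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
    )

    datamodule = VlpnDataModule(
        batch_size=1,
        data_file="data/train.json",
        tokenizer=tokenizer,  # was: prompt_processor=...
        num_class_images=1,
        size=768,
    )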

diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image
 
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
+from transformers import CLIPTokenizer
 
 from data.keywords import prompt_to_keywords, keywords_to_prompt
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids
 
 
 image_cache: dict[str, Image.Image] = {}
@@ -102,7 +103,7 @@ def generate_buckets(
 def collate_fn(
     num_class_images: int,
     weight_dtype: torch.dtype,
-    prompt_processor: PromptProcessor,
+    tokenizer: CLIPTokenizer,
     examples
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
@@ -119,9 +120,9 @@ def collate_fn(
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-    prompts = prompt_processor.unify_input_ids(prompt_ids)
-    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-    inputs = prompt_processor.unify_input_ids(input_ids)
+    prompts = unify_input_ids(tokenizer, prompt_ids)
+    nprompts = unify_input_ids(tokenizer, nprompt_ids)
+    inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -148,7 +149,7 @@ class VlpnDataModule():
         self,
         batch_size: int,
         data_file: str,
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
@@ -179,7 +180,7 @@ class VlpnDataModule():
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
         self.bucket_step_size = bucket_step_size
@@ -272,7 +273,7 @@ class VlpnDataModule():
         self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
-            self.data_train, self.prompt_processor,
+            self.data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
@@ -281,7 +282,7 @@ class VlpnDataModule():
         )
 
         val_dataset = VlpnDataset(
-            self.data_val, self.prompt_processor,
+            self.data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             repeat=self.valid_set_repeat,
@@ -289,7 +290,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset):
     def __init__(
         self,
         items: list[VlpnDataItem],
-        prompt_processor: PromptProcessor,
+        tokenizer: CLIPTokenizer,
         num_buckets: int = 1,
         bucket_step_size: int = 64,
         bucket_max_pixels: Optional[int] = None,
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset):
         self.items = items * repeat
         self.batch_size = batch_size
 
-        self.prompt_processor = prompt_processor
+        self.tokenizer = tokenizer
         self.num_class_images = num_class_images
         self.size = size
         self.dropout = dropout
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset):
 
         self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
+    def get_input_ids(self, text: str):
+        return self.tokenizer(text, padding="do_not_pad").input_ids
+
     def __len__(self):
         return self.length_
 
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset):
 
                 example = {}
 
-                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
-                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
+                example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
+                example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-                example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+                example["instance_prompt_ids"] = self.get_input_ids(
                     keywords_to_prompt(item.prompt, self.dropout, True)
                 )
                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
                 if self.num_class_images != 0:
-                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
+                    example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
                     example["class_images"] = image_transforms(get_image(item.class_image_path))
 
                 batch.append(example)
-- 