From c90099f06e0b461660b326fb6d86b69d86e78289 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 3 Oct 2022 14:47:01 +0200
Subject: Added negative prompt support for training scripts

---
 data/dreambooth/csv.py        | 15 +++++++++------
 data/textual_inversion/csv.py | 17 ++++++++++-------
 2 files changed, 19 insertions(+), 13 deletions(-)

(limited to 'data')

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 08ed49c..71aa1eb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -135,7 +136,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -146,9 +147,10 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.identifier)
+        prompt = prompt.format(self.identifier)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
@@ -178,6 +180,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
 
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 3ac57df..64f0c28 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -109,7 +110,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -120,12 +121,13 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.placeholder_token)
+        prompt = prompt.format(self.placeholder_token)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
-            text,
+            prompt,
             padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
@@ -140,6 +142,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["input_ids"] = unprocessed_example["input_ids"]
         example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
 
-- 
cgit v1.2.3-70-g09d2