From c90099f06e0b461660b326fb6d86b69d86e78289 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 14:47:01 +0200 Subject: Added negative prompt support for training scripts --- data/dreambooth/csv.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'data/dreambooth') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 08ed49c..71aa1eb 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule): def prepare_data(self): metadata = pd.read_csv(self.data_file) image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] - captions = [caption for caption in metadata['caption'].values] - skips = [skip for skip in metadata['skip'].values] - self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] + prompts = metadata['prompt'].values + nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) + skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) + self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] def setup(self, stage=None): valid_set_size = int(len(self.data_full) * 0.2) @@ -135,7 +136,7 @@ class CSVDataset(Dataset): return math.ceil(self._length / self.batch_size) * self.batch_size def get_example(self, i): - image_path, text = self.data[i % self.num_instance_images] + image_path, prompt, nprompt = self.data[i % self.num_instance_images] if image_path in self.cache: return self.cache[image_path] @@ -146,9 +147,10 @@ class CSVDataset(Dataset): if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - text = text.format(self.identifier) + prompt = prompt.format(self.identifier) - example["prompts"] = text + example["prompts"] = prompt + example["nprompts"] = nprompt example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( self.instance_prompt, @@ -178,6 +180,7 @@ class CSVDataset(Dataset): unprocessed_example = self.get_example(i) example["prompts"] = unprocessed_example["prompts"] + example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] -- cgit v1.2.3-54-g00ecf