summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 14:47:01 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 14:47:01 +0200
commitc90099f06e0b461660b326fb6d86b69d86e78289 (patch)
treedf4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /data
parentFixed euler_a generator argument (diff)
downloadtextual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.gz
textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.bz2
textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.zip
Added negative prompt support for training scripts
Diffstat (limited to 'data')
-rw-r--r--data/dreambooth/csv.py15
-rw-r--r--data/textual_inversion/csv.py17
2 files changed, 19 insertions, 13 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 08ed49c..71aa1eb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule):
49 def prepare_data(self): 49 def prepare_data(self):
50 metadata = pd.read_csv(self.data_file) 50 metadata = pd.read_csv(self.data_file)
51 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 51 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
52 captions = [caption for caption in metadata['caption'].values] 52 prompts = metadata['prompt'].values
53 skips = [skip for skip in metadata['skip'].values] 53 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
54 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] 54 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
55 self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
55 56
56 def setup(self, stage=None): 57 def setup(self, stage=None):
57 valid_set_size = int(len(self.data_full) * 0.2) 58 valid_set_size = int(len(self.data_full) * 0.2)
@@ -135,7 +136,7 @@ class CSVDataset(Dataset):
135 return math.ceil(self._length / self.batch_size) * self.batch_size 136 return math.ceil(self._length / self.batch_size) * self.batch_size
136 137
137 def get_example(self, i): 138 def get_example(self, i):
138 image_path, text = self.data[i % self.num_instance_images] 139 image_path, prompt, nprompt = self.data[i % self.num_instance_images]
139 140
140 if image_path in self.cache: 141 if image_path in self.cache:
141 return self.cache[image_path] 142 return self.cache[image_path]
@@ -146,9 +147,10 @@ class CSVDataset(Dataset):
146 if not instance_image.mode == "RGB": 147 if not instance_image.mode == "RGB":
147 instance_image = instance_image.convert("RGB") 148 instance_image = instance_image.convert("RGB")
148 149
149 text = text.format(self.identifier) 150 prompt = prompt.format(self.identifier)
150 151
151 example["prompts"] = text 152 example["prompts"] = prompt
153 example["nprompts"] = nprompt
152 example["instance_images"] = instance_image 154 example["instance_images"] = instance_image
153 example["instance_prompt_ids"] = self.tokenizer( 155 example["instance_prompt_ids"] = self.tokenizer(
154 self.instance_prompt, 156 self.instance_prompt,
@@ -178,6 +180,7 @@ class CSVDataset(Dataset):
178 unprocessed_example = self.get_example(i) 180 unprocessed_example = self.get_example(i)
179 181
180 example["prompts"] = unprocessed_example["prompts"] 182 example["prompts"] = unprocessed_example["prompts"]
183 example["nprompts"] = unprocessed_example["nprompts"]
181 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 184 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
182 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 185 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
183 186
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 3ac57df..64f0c28 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule):
43 def prepare_data(self): 43 def prepare_data(self):
44 metadata = pd.read_csv(self.data_file) 44 metadata = pd.read_csv(self.data_file)
45 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 45 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
46 captions = [caption for caption in metadata['caption'].values] 46 prompts = metadata['prompt'].values
47 skips = [skip for skip in metadata['skip'].values] 47 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
48 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] 48 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
49 self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
49 50
50 def setup(self, stage=None): 51 def setup(self, stage=None):
51 valid_set_size = int(len(self.data_full) * 0.2) 52 valid_set_size = int(len(self.data_full) * 0.2)
@@ -109,7 +110,7 @@ class CSVDataset(Dataset):
109 return math.ceil(self._length / self.batch_size) * self.batch_size 110 return math.ceil(self._length / self.batch_size) * self.batch_size
110 111
111 def get_example(self, i): 112 def get_example(self, i):
112 image_path, text = self.data[i % self.num_instance_images] 113 image_path, prompt, nprompt = self.data[i % self.num_instance_images]
113 114
114 if image_path in self.cache: 115 if image_path in self.cache:
115 return self.cache[image_path] 116 return self.cache[image_path]
@@ -120,12 +121,13 @@ class CSVDataset(Dataset):
120 if not instance_image.mode == "RGB": 121 if not instance_image.mode == "RGB":
121 instance_image = instance_image.convert("RGB") 122 instance_image = instance_image.convert("RGB")
122 123
123 text = text.format(self.placeholder_token) 124 prompt = prompt.format(self.placeholder_token)
124 125
125 example["prompts"] = text 126 example["prompts"] = prompt
127 example["nprompts"] = nprompt
126 example["pixel_values"] = instance_image 128 example["pixel_values"] = instance_image
127 example["input_ids"] = self.tokenizer( 129 example["input_ids"] = self.tokenizer(
128 text, 130 prompt,
129 padding="max_length", 131 padding="max_length",
130 truncation=True, 132 truncation=True,
131 max_length=self.tokenizer.model_max_length, 133 max_length=self.tokenizer.model_max_length,
@@ -140,6 +142,7 @@ class CSVDataset(Dataset):
140 unprocessed_example = self.get_example(i) 142 unprocessed_example = self.get_example(i)
141 143
142 example["prompts"] = unprocessed_example["prompts"] 144 example["prompts"] = unprocessed_example["prompts"]
145 example["nprompts"] = unprocessed_example["nprompts"]
143 example["input_ids"] = unprocessed_example["input_ids"] 146 example["input_ids"] = unprocessed_example["input_ids"]
144 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) 147 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
145 148