diff options
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- | data/dreambooth/csv.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 08ed49c..71aa1eb 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
49 | def prepare_data(self): | 49 | def prepare_data(self): |
50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
52 | captions = [caption for caption in metadata['caption'].values] | 52 | prompts = metadata['prompt'].values |
53 | skips = [skip for skip in metadata['skip'].values] | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
54 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
55 | 56 | ||
56 | def setup(self, stage=None): | 57 | def setup(self, stage=None): |
57 | valid_set_size = int(len(self.data_full) * 0.2) | 58 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -135,7 +136,7 @@ class CSVDataset(Dataset): | |||
135 | return math.ceil(self._length / self.batch_size) * self.batch_size | 136 | return math.ceil(self._length / self.batch_size) * self.batch_size |
136 | 137 | ||
137 | def get_example(self, i): | 138 | def get_example(self, i): |
138 | image_path, text = self.data[i % self.num_instance_images] | 139 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
139 | 140 | ||
140 | if image_path in self.cache: | 141 | if image_path in self.cache: |
141 | return self.cache[image_path] | 142 | return self.cache[image_path] |
@@ -146,9 +147,10 @@ class CSVDataset(Dataset): | |||
146 | if not instance_image.mode == "RGB": | 147 | if not instance_image.mode == "RGB": |
147 | instance_image = instance_image.convert("RGB") | 148 | instance_image = instance_image.convert("RGB") |
148 | 149 | ||
149 | text = text.format(self.identifier) | 150 | prompt = prompt.format(self.identifier) |
150 | 151 | ||
151 | example["prompts"] = text | 152 | example["prompts"] = prompt |
153 | example["nprompts"] = nprompt | ||
152 | example["instance_images"] = instance_image | 154 | example["instance_images"] = instance_image |
153 | example["instance_prompt_ids"] = self.tokenizer( | 155 | example["instance_prompt_ids"] = self.tokenizer( |
154 | self.instance_prompt, | 156 | self.instance_prompt, |
@@ -178,6 +180,7 @@ class CSVDataset(Dataset): | |||
178 | unprocessed_example = self.get_example(i) | 180 | unprocessed_example = self.get_example(i) |
179 | 181 | ||
180 | example["prompts"] = unprocessed_example["prompts"] | 182 | example["prompts"] = unprocessed_example["prompts"] |
183 | example["nprompts"] = unprocessed_example["nprompts"] | ||
181 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 184 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
182 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 185 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
183 | 186 | ||