diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/dreambooth/csv.py | 15 | ||||
| -rw-r--r-- | data/textual_inversion/csv.py | 17 |
2 files changed, 19 insertions, 13 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 08ed49c..71aa1eb 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 49 | def prepare_data(self): | 49 | def prepare_data(self): |
| 50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
| 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
| 52 | captions = [caption for caption in metadata['caption'].values] | 52 | prompts = metadata['prompt'].values |
| 53 | skips = [skip for skip in metadata['skip'].values] | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
| 54 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
| 55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
| 55 | 56 | ||
| 56 | def setup(self, stage=None): | 57 | def setup(self, stage=None): |
| 57 | valid_set_size = int(len(self.data_full) * 0.2) | 58 | valid_set_size = int(len(self.data_full) * 0.2) |
| @@ -135,7 +136,7 @@ class CSVDataset(Dataset): | |||
| 135 | return math.ceil(self._length / self.batch_size) * self.batch_size | 136 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 136 | 137 | ||
| 137 | def get_example(self, i): | 138 | def get_example(self, i): |
| 138 | image_path, text = self.data[i % self.num_instance_images] | 139 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
| 139 | 140 | ||
| 140 | if image_path in self.cache: | 141 | if image_path in self.cache: |
| 141 | return self.cache[image_path] | 142 | return self.cache[image_path] |
| @@ -146,9 +147,10 @@ class CSVDataset(Dataset): | |||
| 146 | if not instance_image.mode == "RGB": | 147 | if not instance_image.mode == "RGB": |
| 147 | instance_image = instance_image.convert("RGB") | 148 | instance_image = instance_image.convert("RGB") |
| 148 | 149 | ||
| 149 | text = text.format(self.identifier) | 150 | prompt = prompt.format(self.identifier) |
| 150 | 151 | ||
| 151 | example["prompts"] = text | 152 | example["prompts"] = prompt |
| 153 | example["nprompts"] = nprompt | ||
| 152 | example["instance_images"] = instance_image | 154 | example["instance_images"] = instance_image |
| 153 | example["instance_prompt_ids"] = self.tokenizer( | 155 | example["instance_prompt_ids"] = self.tokenizer( |
| 154 | self.instance_prompt, | 156 | self.instance_prompt, |
| @@ -178,6 +180,7 @@ class CSVDataset(Dataset): | |||
| 178 | unprocessed_example = self.get_example(i) | 180 | unprocessed_example = self.get_example(i) |
| 179 | 181 | ||
| 180 | example["prompts"] = unprocessed_example["prompts"] | 182 | example["prompts"] = unprocessed_example["prompts"] |
| 183 | example["nprompts"] = unprocessed_example["nprompts"] | ||
| 181 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 184 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 182 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 185 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
| 183 | 186 | ||
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 3ac57df..64f0c28 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 43 | def prepare_data(self): | 43 | def prepare_data(self): |
| 44 | metadata = pd.read_csv(self.data_file) | 44 | metadata = pd.read_csv(self.data_file) |
| 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
| 46 | captions = [caption for caption in metadata['caption'].values] | 46 | prompts = metadata['prompt'].values |
| 47 | skips = [skip for skip in metadata['skip'].values] | 47 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
| 48 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 48 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
| 49 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
| 49 | 50 | ||
| 50 | def setup(self, stage=None): | 51 | def setup(self, stage=None): |
| 51 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
| @@ -109,7 +110,7 @@ class CSVDataset(Dataset): | |||
| 109 | return math.ceil(self._length / self.batch_size) * self.batch_size | 110 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 110 | 111 | ||
| 111 | def get_example(self, i): | 112 | def get_example(self, i): |
| 112 | image_path, text = self.data[i % self.num_instance_images] | 113 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
| 113 | 114 | ||
| 114 | if image_path in self.cache: | 115 | if image_path in self.cache: |
| 115 | return self.cache[image_path] | 116 | return self.cache[image_path] |
| @@ -120,12 +121,13 @@ class CSVDataset(Dataset): | |||
| 120 | if not instance_image.mode == "RGB": | 121 | if not instance_image.mode == "RGB": |
| 121 | instance_image = instance_image.convert("RGB") | 122 | instance_image = instance_image.convert("RGB") |
| 122 | 123 | ||
| 123 | text = text.format(self.placeholder_token) | 124 | prompt = prompt.format(self.placeholder_token) |
| 124 | 125 | ||
| 125 | example["prompts"] = text | 126 | example["prompts"] = prompt |
| 127 | example["nprompts"] = nprompt | ||
| 126 | example["pixel_values"] = instance_image | 128 | example["pixel_values"] = instance_image |
| 127 | example["input_ids"] = self.tokenizer( | 129 | example["input_ids"] = self.tokenizer( |
| 128 | text, | 130 | prompt, |
| 129 | padding="max_length", | 131 | padding="max_length", |
| 130 | truncation=True, | 132 | truncation=True, |
| 131 | max_length=self.tokenizer.model_max_length, | 133 | max_length=self.tokenizer.model_max_length, |
| @@ -140,6 +142,7 @@ class CSVDataset(Dataset): | |||
| 140 | unprocessed_example = self.get_example(i) | 142 | unprocessed_example = self.get_example(i) |
| 141 | 143 | ||
| 142 | example["prompts"] = unprocessed_example["prompts"] | 144 | example["prompts"] = unprocessed_example["prompts"] |
| 145 | example["nprompts"] = unprocessed_example["nprompts"] | ||
| 143 | example["input_ids"] = unprocessed_example["input_ids"] | 146 | example["input_ids"] = unprocessed_example["input_ids"] |
| 144 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | 147 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) |
| 145 | 148 | ||
