diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/dreambooth/csv.py | 15 | ||||
-rw-r--r-- | data/textual_inversion/csv.py | 17 |
2 files changed, 19 insertions, 13 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 08ed49c..71aa1eb 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
49 | def prepare_data(self): | 49 | def prepare_data(self): |
50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
52 | captions = [caption for caption in metadata['caption'].values] | 52 | prompts = metadata['prompt'].values |
53 | skips = [skip for skip in metadata['skip'].values] | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
54 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
55 | 56 | ||
56 | def setup(self, stage=None): | 57 | def setup(self, stage=None): |
57 | valid_set_size = int(len(self.data_full) * 0.2) | 58 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -135,7 +136,7 @@ class CSVDataset(Dataset): | |||
135 | return math.ceil(self._length / self.batch_size) * self.batch_size | 136 | return math.ceil(self._length / self.batch_size) * self.batch_size |
136 | 137 | ||
137 | def get_example(self, i): | 138 | def get_example(self, i): |
138 | image_path, text = self.data[i % self.num_instance_images] | 139 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
139 | 140 | ||
140 | if image_path in self.cache: | 141 | if image_path in self.cache: |
141 | return self.cache[image_path] | 142 | return self.cache[image_path] |
@@ -146,9 +147,10 @@ class CSVDataset(Dataset): | |||
146 | if not instance_image.mode == "RGB": | 147 | if not instance_image.mode == "RGB": |
147 | instance_image = instance_image.convert("RGB") | 148 | instance_image = instance_image.convert("RGB") |
148 | 149 | ||
149 | text = text.format(self.identifier) | 150 | prompt = prompt.format(self.identifier) |
150 | 151 | ||
151 | example["prompts"] = text | 152 | example["prompts"] = prompt |
153 | example["nprompts"] = nprompt | ||
152 | example["instance_images"] = instance_image | 154 | example["instance_images"] = instance_image |
153 | example["instance_prompt_ids"] = self.tokenizer( | 155 | example["instance_prompt_ids"] = self.tokenizer( |
154 | self.instance_prompt, | 156 | self.instance_prompt, |
@@ -178,6 +180,7 @@ class CSVDataset(Dataset): | |||
178 | unprocessed_example = self.get_example(i) | 180 | unprocessed_example = self.get_example(i) |
179 | 181 | ||
180 | example["prompts"] = unprocessed_example["prompts"] | 182 | example["prompts"] = unprocessed_example["prompts"] |
183 | example["nprompts"] = unprocessed_example["nprompts"] | ||
181 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 184 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
182 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 185 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
183 | 186 | ||
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 3ac57df..64f0c28 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
43 | def prepare_data(self): | 43 | def prepare_data(self): |
44 | metadata = pd.read_csv(self.data_file) | 44 | metadata = pd.read_csv(self.data_file) |
45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
46 | captions = [caption for caption in metadata['caption'].values] | 46 | prompts = metadata['prompt'].values |
47 | skips = [skip for skip in metadata['skip'].values] | 47 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
48 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 48 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
49 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
49 | 50 | ||
50 | def setup(self, stage=None): | 51 | def setup(self, stage=None): |
51 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -109,7 +110,7 @@ class CSVDataset(Dataset): | |||
109 | return math.ceil(self._length / self.batch_size) * self.batch_size | 110 | return math.ceil(self._length / self.batch_size) * self.batch_size |
110 | 111 | ||
111 | def get_example(self, i): | 112 | def get_example(self, i): |
112 | image_path, text = self.data[i % self.num_instance_images] | 113 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
113 | 114 | ||
114 | if image_path in self.cache: | 115 | if image_path in self.cache: |
115 | return self.cache[image_path] | 116 | return self.cache[image_path] |
@@ -120,12 +121,13 @@ class CSVDataset(Dataset): | |||
120 | if not instance_image.mode == "RGB": | 121 | if not instance_image.mode == "RGB": |
121 | instance_image = instance_image.convert("RGB") | 122 | instance_image = instance_image.convert("RGB") |
122 | 123 | ||
123 | text = text.format(self.placeholder_token) | 124 | prompt = prompt.format(self.placeholder_token) |
124 | 125 | ||
125 | example["prompts"] = text | 126 | example["prompts"] = prompt |
127 | example["nprompts"] = nprompt | ||
126 | example["pixel_values"] = instance_image | 128 | example["pixel_values"] = instance_image |
127 | example["input_ids"] = self.tokenizer( | 129 | example["input_ids"] = self.tokenizer( |
128 | text, | 130 | prompt, |
129 | padding="max_length", | 131 | padding="max_length", |
130 | truncation=True, | 132 | truncation=True, |
131 | max_length=self.tokenizer.model_max_length, | 133 | max_length=self.tokenizer.model_max_length, |
@@ -140,6 +142,7 @@ class CSVDataset(Dataset): | |||
140 | unprocessed_example = self.get_example(i) | 142 | unprocessed_example = self.get_example(i) |
141 | 143 | ||
142 | example["prompts"] = unprocessed_example["prompts"] | 144 | example["prompts"] = unprocessed_example["prompts"] |
145 | example["nprompts"] = unprocessed_example["nprompts"] | ||
143 | example["input_ids"] = unprocessed_example["input_ids"] | 146 | example["input_ids"] = unprocessed_example["input_ids"] |
144 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | 147 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) |
145 | 148 | ||