diff options
Diffstat (limited to 'data/textual_inversion')
-rw-r--r-- | data/textual_inversion/csv.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 3ac57df..64f0c28 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
43 | def prepare_data(self): | 43 | def prepare_data(self): |
44 | metadata = pd.read_csv(self.data_file) | 44 | metadata = pd.read_csv(self.data_file) |
45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
46 | captions = [caption for caption in metadata['caption'].values] | 46 | prompts = metadata['prompt'].values |
47 | skips = [skip for skip in metadata['skip'].values] | 47 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
48 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 48 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
49 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
49 | 50 | ||
50 | def setup(self, stage=None): | 51 | def setup(self, stage=None): |
51 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -109,7 +110,7 @@ class CSVDataset(Dataset): | |||
109 | return math.ceil(self._length / self.batch_size) * self.batch_size | 110 | return math.ceil(self._length / self.batch_size) * self.batch_size |
110 | 111 | ||
111 | def get_example(self, i): | 112 | def get_example(self, i): |
112 | image_path, text = self.data[i % self.num_instance_images] | 113 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
113 | 114 | ||
114 | if image_path in self.cache: | 115 | if image_path in self.cache: |
115 | return self.cache[image_path] | 116 | return self.cache[image_path] |
@@ -120,12 +121,13 @@ class CSVDataset(Dataset): | |||
120 | if not instance_image.mode == "RGB": | 121 | if not instance_image.mode == "RGB": |
121 | instance_image = instance_image.convert("RGB") | 122 | instance_image = instance_image.convert("RGB") |
122 | 123 | ||
123 | text = text.format(self.placeholder_token) | 124 | prompt = prompt.format(self.placeholder_token) |
124 | 125 | ||
125 | example["prompts"] = text | 126 | example["prompts"] = prompt |
127 | example["nprompts"] = nprompt | ||
126 | example["pixel_values"] = instance_image | 128 | example["pixel_values"] = instance_image |
127 | example["input_ids"] = self.tokenizer( | 129 | example["input_ids"] = self.tokenizer( |
128 | text, | 130 | prompt, |
129 | padding="max_length", | 131 | padding="max_length", |
130 | truncation=True, | 132 | truncation=True, |
131 | max_length=self.tokenizer.model_max_length, | 133 | max_length=self.tokenizer.model_max_length, |
@@ -140,6 +142,7 @@ class CSVDataset(Dataset): | |||
140 | unprocessed_example = self.get_example(i) | 142 | unprocessed_example = self.get_example(i) |
141 | 143 | ||
142 | example["prompts"] = unprocessed_example["prompts"] | 144 | example["prompts"] = unprocessed_example["prompts"] |
145 | example["nprompts"] = unprocessed_example["nprompts"] | ||
143 | example["input_ids"] = unprocessed_example["input_ids"] | 146 | example["input_ids"] = unprocessed_example["input_ids"] |
144 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | 147 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) |
145 | 148 | ||