diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
| commit | c90099f06e0b461660b326fb6d86b69d86e78289 (patch) | |
| tree | df4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /data/textual_inversion | |
| parent | Fixed euler_a generator argument (diff) | |
| download | textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.gz textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.bz2 textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.zip | |
Added negative prompt support for training scripts
Diffstat (limited to 'data/textual_inversion')
| -rw-r--r-- | data/textual_inversion/csv.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 3ac57df..64f0c28 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 43 | def prepare_data(self): | 43 | def prepare_data(self): |
| 44 | metadata = pd.read_csv(self.data_file) | 44 | metadata = pd.read_csv(self.data_file) |
| 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
| 46 | captions = [caption for caption in metadata['caption'].values] | 46 | prompts = metadata['prompt'].values |
| 47 | skips = [skip for skip in metadata['skip'].values] | 47 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
| 48 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 48 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
| 49 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
| 49 | 50 | ||
| 50 | def setup(self, stage=None): | 51 | def setup(self, stage=None): |
| 51 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
| @@ -109,7 +110,7 @@ class CSVDataset(Dataset): | |||
| 109 | return math.ceil(self._length / self.batch_size) * self.batch_size | 110 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 110 | 111 | ||
| 111 | def get_example(self, i): | 112 | def get_example(self, i): |
| 112 | image_path, text = self.data[i % self.num_instance_images] | 113 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
| 113 | 114 | ||
| 114 | if image_path in self.cache: | 115 | if image_path in self.cache: |
| 115 | return self.cache[image_path] | 116 | return self.cache[image_path] |
| @@ -120,12 +121,13 @@ class CSVDataset(Dataset): | |||
| 120 | if not instance_image.mode == "RGB": | 121 | if not instance_image.mode == "RGB": |
| 121 | instance_image = instance_image.convert("RGB") | 122 | instance_image = instance_image.convert("RGB") |
| 122 | 123 | ||
| 123 | text = text.format(self.placeholder_token) | 124 | prompt = prompt.format(self.placeholder_token) |
| 124 | 125 | ||
| 125 | example["prompts"] = text | 126 | example["prompts"] = prompt |
| 127 | example["nprompts"] = nprompt | ||
| 126 | example["pixel_values"] = instance_image | 128 | example["pixel_values"] = instance_image |
| 127 | example["input_ids"] = self.tokenizer( | 129 | example["input_ids"] = self.tokenizer( |
| 128 | text, | 130 | prompt, |
| 129 | padding="max_length", | 131 | padding="max_length", |
| 130 | truncation=True, | 132 | truncation=True, |
| 131 | max_length=self.tokenizer.model_max_length, | 133 | max_length=self.tokenizer.model_max_length, |
| @@ -140,6 +142,7 @@ class CSVDataset(Dataset): | |||
| 140 | unprocessed_example = self.get_example(i) | 142 | unprocessed_example = self.get_example(i) |
| 141 | 143 | ||
| 142 | example["prompts"] = unprocessed_example["prompts"] | 144 | example["prompts"] = unprocessed_example["prompts"] |
| 145 | example["nprompts"] = unprocessed_example["nprompts"] | ||
| 143 | example["input_ids"] = unprocessed_example["input_ids"] | 146 | example["input_ids"] = unprocessed_example["input_ids"] |
| 144 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | 147 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) |
| 145 | 148 | ||
