diff options
author | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
commit | 728dfcf57c30f40236b3a00d7380c4e0057cacb3 (patch) | |
tree | 9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /data | |
parent | Upstream updates; better handling of textual embedding (diff) | |
download | textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.gz textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.bz2 textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.zip |
Implemented extended prompt limit
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py index aad970c..316c099 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
72 | ] | 72 | ] |
73 | 73 | ||
74 | def prepare_data(self): | 74 | def prepare_data(self): |
75 | metadata = pd.read_csv(self.data_file) | 75 | metadata = pd.read_json(self.data_file) |
76 | metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"] | 76 | metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True] |
77 | num_images = len(metadata) | 77 | num_images = len(metadata) |
78 | 78 | ||
79 | valid_set_size = int(num_images * 0.2) | 79 | valid_set_size = int(num_images * 0.2) |
@@ -163,6 +163,12 @@ class CSVDataset(Dataset): | |||
163 | 163 | ||
164 | example = {} | 164 | example = {} |
165 | 165 | ||
166 | if isinstance(item.prompt, str): | ||
167 | item.prompt = [item.prompt] | ||
168 | |||
169 | if isinstance(item.nprompt, str): | ||
170 | item.nprompt = [item.nprompt] | ||
171 | |||
166 | example["prompts"] = item.prompt | 172 | example["prompts"] = item.prompt |
167 | example["nprompts"] = item.nprompt | 173 | example["nprompts"] = item.nprompt |
168 | 174 | ||
@@ -177,7 +183,7 @@ class CSVDataset(Dataset): | |||
177 | example["instance_images"] = instance_image | 183 | example["instance_images"] = instance_image |
178 | example["instance_prompt_ids"] = self.tokenizer( | 184 | example["instance_prompt_ids"] = self.tokenizer( |
179 | item.prompt.format(self.instance_identifier), | 185 | item.prompt.format(self.instance_identifier), |
180 | padding="do_not_pad", | 186 | padding="max_length", |
181 | truncation=True, | 187 | truncation=True, |
182 | max_length=self.tokenizer.model_max_length, | 188 | max_length=self.tokenizer.model_max_length, |
183 | ).input_ids | 189 | ).input_ids |
@@ -190,7 +196,7 @@ class CSVDataset(Dataset): | |||
190 | example["class_images"] = class_image | 196 | example["class_images"] = class_image |
191 | example["class_prompt_ids"] = self.tokenizer( | 197 | example["class_prompt_ids"] = self.tokenizer( |
192 | item.prompt.format(self.class_identifier), | 198 | item.prompt.format(self.class_identifier), |
193 | padding="do_not_pad", | 199 | padding="max_length", |
194 | truncation=True, | 200 | truncation=True, |
195 | max_length=self.tokenizer.model_max_length, | 201 | max_length=self.tokenizer.model_max_length, |
196 | ).input_ids | 202 | ).input_ids |