diff options
author | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
commit | 1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch) | |
tree | 42d3302610046dbc5d39d254f7a2d5d5f601aa18 /data | |
parent | Fix (diff) | |
download | textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2 textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip |
Improved class prompt handling
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 29 |
1 files changed, 12 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py index d400757..e25dd3f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): | |||
38 | instance_image_path: Path | 38 | instance_image_path: Path |
39 | class_image_path: Path | 39 | class_image_path: Path |
40 | prompt: list[str] | 40 | prompt: list[str] |
41 | cprompt: str | ||
41 | nprompt: str | 42 | nprompt: str |
42 | 43 | ||
43 | 44 | ||
@@ -47,8 +48,6 @@ class CSVDataModule(): | |||
47 | batch_size: int, | 48 | batch_size: int, |
48 | data_file: str, | 49 | data_file: str, |
49 | prompt_processor: PromptProcessor, | 50 | prompt_processor: PromptProcessor, |
50 | instance_identifier: str, | ||
51 | class_identifier: Optional[str] = None, | ||
52 | class_subdir: str = "cls", | 51 | class_subdir: str = "cls", |
53 | num_class_images: int = 100, | 52 | num_class_images: int = 100, |
54 | size: int = 512, | 53 | size: int = 512, |
@@ -77,8 +76,6 @@ class CSVDataModule(): | |||
77 | self.num_class_images = num_class_images | 76 | self.num_class_images = num_class_images |
78 | 77 | ||
79 | self.prompt_processor = prompt_processor | 78 | self.prompt_processor = prompt_processor |
80 | self.instance_identifier = instance_identifier | ||
81 | self.class_identifier = class_identifier | ||
82 | self.size = size | 79 | self.size = size |
83 | self.repeats = repeats | 80 | self.repeats = repeats |
84 | self.dropout = dropout | 81 | self.dropout = dropout |
@@ -96,14 +93,18 @@ class CSVDataModule(): | |||
96 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 93 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
97 | image = template["image"] if "image" in template else "{}" | 94 | image = template["image"] if "image" in template else "{}" |
98 | prompt = template["prompt"] if "prompt" in template else "{content}" | 95 | prompt = template["prompt"] if "prompt" in template else "{content}" |
96 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | ||
99 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
100 | 98 | ||
101 | return [ | 99 | return [ |
102 | CSVDataItem( | 100 | CSVDataItem( |
103 | self.data_root.joinpath(image.format(item["image"])), | 101 | self.data_root.joinpath(image.format(item["image"])), |
104 | None, | 102 | None, |
105 | prompt_to_keywords(prompt.format( | 103 | prompt_to_keywords( |
106 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | 104 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
105 | expansions | ||
106 | ), | ||
107 | cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), | ||
107 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
108 | ) | 109 | ) |
109 | for item in data | 110 | for item in data |
@@ -123,6 +124,7 @@ class CSVDataModule(): | |||
123 | item.instance_image_path, | 124 | item.instance_image_path, |
124 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 125 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
125 | item.prompt, | 126 | item.prompt, |
127 | item.cprompt, | ||
126 | item.nprompt, | 128 | item.nprompt, |
127 | ) | 129 | ) |
128 | for item in items | 130 | for item in items |
@@ -160,12 +162,10 @@ class CSVDataModule(): | |||
160 | 162 | ||
161 | def setup(self, stage=None): | 163 | def setup(self, stage=None): |
162 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 164 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
163 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | ||
164 | num_class_images=self.num_class_images, | 165 | num_class_images=self.num_class_images, |
165 | size=self.size, interpolation=self.interpolation, | 166 | size=self.size, interpolation=self.interpolation, |
166 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 167 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
167 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 168 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
168 | instance_identifier=self.instance_identifier, | ||
169 | size=self.size, interpolation=self.interpolation, | 169 | size=self.size, interpolation=self.interpolation, |
170 | center_crop=self.center_crop) | 170 | center_crop=self.center_crop) |
171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
@@ -187,9 +187,7 @@ class CSVDataset(Dataset): | |||
187 | self, | 187 | self, |
188 | data: List[CSVDataItem], | 188 | data: List[CSVDataItem], |
189 | prompt_processor: PromptProcessor, | 189 | prompt_processor: PromptProcessor, |
190 | instance_identifier: str, | ||
191 | batch_size: int = 1, | 190 | batch_size: int = 1, |
192 | class_identifier: Optional[str] = None, | ||
193 | num_class_images: int = 0, | 191 | num_class_images: int = 0, |
194 | size: int = 512, | 192 | size: int = 512, |
195 | repeats: int = 1, | 193 | repeats: int = 1, |
@@ -201,8 +199,6 @@ class CSVDataset(Dataset): | |||
201 | self.data = data | 199 | self.data = data |
202 | self.prompt_processor = prompt_processor | 200 | self.prompt_processor = prompt_processor |
203 | self.batch_size = batch_size | 201 | self.batch_size = batch_size |
204 | self.instance_identifier = instance_identifier | ||
205 | self.class_identifier = class_identifier | ||
206 | self.num_class_images = num_class_images | 202 | self.num_class_images = num_class_images |
207 | self.dropout = dropout | 203 | self.dropout = dropout |
208 | self.image_cache = {} | 204 | self.image_cache = {} |
@@ -239,14 +235,12 @@ class CSVDataset(Dataset): | |||
239 | 235 | ||
240 | return image | 236 | return image |
241 | 237 | ||
242 | def get_input_ids(self, prompt, identifier): | ||
243 | return self.prompt_processor.get_input_ids(prompt.format(identifier)) | ||
244 | |||
245 | def get_example(self, i): | 238 | def get_example(self, i): |
246 | item = self.data[i % self.num_instance_images] | 239 | item = self.data[i % self.num_instance_images] |
247 | 240 | ||
248 | example = {} | 241 | example = {} |
249 | example["prompts"] = item.prompt | 242 | example["prompts"] = item.prompt |
243 | example["cprompts"] = item.cprompt | ||
250 | example["nprompts"] = item.nprompt | 244 | example["nprompts"] = item.nprompt |
251 | example["instance_images"] = self.get_image(item.instance_image_path) | 245 | example["instance_images"] = self.get_image(item.instance_image_path) |
252 | if self.num_class_images != 0: | 246 | if self.num_class_images != 0: |
@@ -260,13 +254,14 @@ class CSVDataset(Dataset): | |||
260 | example = {} | 254 | example = {} |
261 | 255 | ||
262 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) | 256 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
257 | example["cprompts"] = unprocessed_example["cprompts"] | ||
263 | example["nprompts"] = unprocessed_example["nprompts"] | 258 | example["nprompts"] = unprocessed_example["nprompts"] |
264 | 259 | ||
265 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 260 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
266 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) | 261 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
267 | 262 | ||
268 | if self.num_class_images != 0: | 263 | if self.num_class_images != 0: |
269 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 264 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
270 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) | 265 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) |
271 | 266 | ||
272 | return example | 267 | return example |