diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 29 |
1 file changed, 12 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index d400757..e25dd3f 100644
--- a/data/csv.py
+++ b/data/csv.py
| @@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): | |||
| 38 | instance_image_path: Path | 38 | instance_image_path: Path |
| 39 | class_image_path: Path | 39 | class_image_path: Path |
| 40 | prompt: list[str] | 40 | prompt: list[str] |
| 41 | cprompt: str | ||
| 41 | nprompt: str | 42 | nprompt: str |
| 42 | 43 | ||
| 43 | 44 | ||
| @@ -47,8 +48,6 @@ class CSVDataModule(): | |||
| 47 | batch_size: int, | 48 | batch_size: int, |
| 48 | data_file: str, | 49 | data_file: str, |
| 49 | prompt_processor: PromptProcessor, | 50 | prompt_processor: PromptProcessor, |
| 50 | instance_identifier: str, | ||
| 51 | class_identifier: Optional[str] = None, | ||
| 52 | class_subdir: str = "cls", | 51 | class_subdir: str = "cls", |
| 53 | num_class_images: int = 100, | 52 | num_class_images: int = 100, |
| 54 | size: int = 512, | 53 | size: int = 512, |
| @@ -77,8 +76,6 @@ class CSVDataModule(): | |||
| 77 | self.num_class_images = num_class_images | 76 | self.num_class_images = num_class_images |
| 78 | 77 | ||
| 79 | self.prompt_processor = prompt_processor | 78 | self.prompt_processor = prompt_processor |
| 80 | self.instance_identifier = instance_identifier | ||
| 81 | self.class_identifier = class_identifier | ||
| 82 | self.size = size | 79 | self.size = size |
| 83 | self.repeats = repeats | 80 | self.repeats = repeats |
| 84 | self.dropout = dropout | 81 | self.dropout = dropout |
| @@ -96,14 +93,18 @@ class CSVDataModule(): | |||
| 96 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 93 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
| 97 | image = template["image"] if "image" in template else "{}" | 94 | image = template["image"] if "image" in template else "{}" |
| 98 | prompt = template["prompt"] if "prompt" in template else "{content}" | 95 | prompt = template["prompt"] if "prompt" in template else "{content}" |
| 96 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | ||
| 99 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| 100 | 98 | ||
| 101 | return [ | 99 | return [ |
| 102 | CSVDataItem( | 100 | CSVDataItem( |
| 103 | self.data_root.joinpath(image.format(item["image"])), | 101 | self.data_root.joinpath(image.format(item["image"])), |
| 104 | None, | 102 | None, |
| 105 | prompt_to_keywords(prompt.format( | 103 | prompt_to_keywords( |
| 106 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | 104 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 105 | expansions | ||
| 106 | ), | ||
| 107 | cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), | ||
| 107 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 108 | ) | 109 | ) |
| 109 | for item in data | 110 | for item in data |
| @@ -123,6 +124,7 @@ class CSVDataModule(): | |||
| 123 | item.instance_image_path, | 124 | item.instance_image_path, |
| 124 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 125 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
| 125 | item.prompt, | 126 | item.prompt, |
| 127 | item.cprompt, | ||
| 126 | item.nprompt, | 128 | item.nprompt, |
| 127 | ) | 129 | ) |
| 128 | for item in items | 130 | for item in items |
| @@ -160,12 +162,10 @@ class CSVDataModule(): | |||
| 160 | 162 | ||
| 161 | def setup(self, stage=None): | 163 | def setup(self, stage=None): |
| 162 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 164 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
| 163 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | ||
| 164 | num_class_images=self.num_class_images, | 165 | num_class_images=self.num_class_images, |
| 165 | size=self.size, interpolation=self.interpolation, | 166 | size=self.size, interpolation=self.interpolation, |
| 166 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 167 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
| 167 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 168 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 168 | instance_identifier=self.instance_identifier, | ||
| 169 | size=self.size, interpolation=self.interpolation, | 169 | size=self.size, interpolation=self.interpolation, |
| 170 | center_crop=self.center_crop) | 170 | center_crop=self.center_crop) |
| 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| @@ -187,9 +187,7 @@ class CSVDataset(Dataset): | |||
| 187 | self, | 187 | self, |
| 188 | data: List[CSVDataItem], | 188 | data: List[CSVDataItem], |
| 189 | prompt_processor: PromptProcessor, | 189 | prompt_processor: PromptProcessor, |
| 190 | instance_identifier: str, | ||
| 191 | batch_size: int = 1, | 190 | batch_size: int = 1, |
| 192 | class_identifier: Optional[str] = None, | ||
| 193 | num_class_images: int = 0, | 191 | num_class_images: int = 0, |
| 194 | size: int = 512, | 192 | size: int = 512, |
| 195 | repeats: int = 1, | 193 | repeats: int = 1, |
| @@ -201,8 +199,6 @@ class CSVDataset(Dataset): | |||
| 201 | self.data = data | 199 | self.data = data |
| 202 | self.prompt_processor = prompt_processor | 200 | self.prompt_processor = prompt_processor |
| 203 | self.batch_size = batch_size | 201 | self.batch_size = batch_size |
| 204 | self.instance_identifier = instance_identifier | ||
| 205 | self.class_identifier = class_identifier | ||
| 206 | self.num_class_images = num_class_images | 202 | self.num_class_images = num_class_images |
| 207 | self.dropout = dropout | 203 | self.dropout = dropout |
| 208 | self.image_cache = {} | 204 | self.image_cache = {} |
| @@ -239,14 +235,12 @@ class CSVDataset(Dataset): | |||
| 239 | 235 | ||
| 240 | return image | 236 | return image |
| 241 | 237 | ||
| 242 | def get_input_ids(self, prompt, identifier): | ||
| 243 | return self.prompt_processor.get_input_ids(prompt.format(identifier)) | ||
| 244 | |||
| 245 | def get_example(self, i): | 238 | def get_example(self, i): |
| 246 | item = self.data[i % self.num_instance_images] | 239 | item = self.data[i % self.num_instance_images] |
| 247 | 240 | ||
| 248 | example = {} | 241 | example = {} |
| 249 | example["prompts"] = item.prompt | 242 | example["prompts"] = item.prompt |
| 243 | example["cprompts"] = item.cprompt | ||
| 250 | example["nprompts"] = item.nprompt | 244 | example["nprompts"] = item.nprompt |
| 251 | example["instance_images"] = self.get_image(item.instance_image_path) | 245 | example["instance_images"] = self.get_image(item.instance_image_path) |
| 252 | if self.num_class_images != 0: | 246 | if self.num_class_images != 0: |
| @@ -260,13 +254,14 @@ class CSVDataset(Dataset): | |||
| 260 | example = {} | 254 | example = {} |
| 261 | 255 | ||
| 262 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) | 256 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
| 257 | example["cprompts"] = unprocessed_example["cprompts"] | ||
| 263 | example["nprompts"] = unprocessed_example["nprompts"] | 258 | example["nprompts"] = unprocessed_example["nprompts"] |
| 264 | 259 | ||
| 265 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 260 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 266 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) | 261 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
| 267 | 262 | ||
| 268 | if self.num_class_images != 0: | 263 | if self.num_class_images != 0: |
| 269 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 264 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 270 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) | 265 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) |
| 271 | 266 | ||
| 272 | return example | 267 | return example |
