From b73469706091c8aaf3f028de96ab017f5a845639 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 13 Dec 2022 20:49:57 +0100 Subject: Optimized Textual Inversion training by filtering dataset by existence of added tokens --- data/csv.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 9125212..9c3c3f8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule): dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, + template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, + keyword_filter: list[str] = [], collate_fn=None, num_workers: int = 0 ): @@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule): self.repeats = repeats self.dropout = dropout self.center_crop = center_crop + self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size self.generator = generator + self.keyword_filter = keyword_filter self.collate_fn = collate_fn self.num_workers = num_workers self.batch_size = batch_size - def prepare_subdata(self, template, data, num_class_images=1): + def prepare_items(self, template, data) -> list[CSVDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" - image_multiplier = max(math.ceil(num_class_images / len(data)), 1) - return [ CSVDataItem( self.data_root.joinpath(image.format(item["image"])), - self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), + None, prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) ) for item in data + ] + + def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: + if len(self.keyword_filter) == 0: + return items + + return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] + + def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: + image_multiplier = max(math.ceil(num_class_images / len(items)), 1) + + return [ + CSVDataItem( + item.instance_image_path, + self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), + item.prompt, + item.nprompt + ) + for item in items for i in range(image_multiplier) ] def prepare_data(self): with open(self.data_file, 'rt') as f: metadata = json.load(f) - template = metadata["template"] if "template" in metadata else {} + template = metadata[self.template_key] if self.template_key in metadata else {} items = metadata["items"] if "items" in metadata else [] items = [item for item in items if not "skip" in item or item["skip"] != True] + items = self.prepare_items(template, items) + items = self.filter_items(items) + num_images = len(items) valid_set_size = int(num_images * 0.1) @@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule): data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) - self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) - self.data_val = self.prepare_subdata(template, data_val) + self.data_train = self.pad_items(data_train, self.num_class_images) + self.data_val = self.pad_items(data_val) def setup(self, stage=None): train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, -- cgit v1.2.3-70-g09d2