Diffstat (limited to 'data')

-rw-r--r--  data/csv.py  38

1 file changed, 31 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py
index 9125212..9c3c3f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule):
         dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
+        template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
+        keyword_filter: list[str] = [],
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule):
         self.repeats = repeats
         self.dropout = dropout
         self.center_crop = center_crop
+        self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
         self.generator = generator
+        self.keyword_filter = keyword_filter
         self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
 
-    def prepare_subdata(self, template, data, num_class_images=1):
+    def prepare_items(self, template, data) -> list[CSVDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
 
-        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
-
         return [
             CSVDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
-                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                None,
                 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
+        ]
+
+    def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
+        if len(self.keyword_filter) == 0:
+            return items
+
+        return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)]
+
+    def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
+        image_multiplier = max(math.ceil(num_class_images / len(items)), 1)
+
+        return [
+            CSVDataItem(
+                item.instance_image_path,
+                self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
+                item.prompt,
+                item.nprompt
+            )
+            for item in items
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
-            template = metadata["template"] if "template" in metadata else {}
+            template = metadata[self.template_key] if self.template_key in metadata else {}
             items = metadata["items"] if "items" in metadata else []
 
             items = [item for item in items if not "skip" in item or item["skip"] != True]
+            items = self.prepare_items(template, items)
+            items = self.filter_items(items)
+
             num_images = len(items)
 
             valid_set_size = int(num_images * 0.1)
@@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule):
 
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-            self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
-            self.data_val = self.prepare_subdata(template, data_val)
+            self.data_train = self.pad_items(data_train, self.num_class_images)
+            self.data_val = self.pad_items(data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
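For readers skimming the diff: the change splits the old prepare_subdata into three steps. prepare_items builds CSVDataItem entries from the metadata (leaving the class image path as None), filter_items drops items whose prompt contains none of the keyword_filter strings, and pad_items replicates items and assigns class image paths so at least num_class_images entries exist. The sketch below replays that flow on toy data; Item, the inline metadata dict, the data/ and class/ paths, and the direct format(content=...) call are hypothetical stand-ins for CSVDataItem, the real JSON file, data_root/class_root, and prepare_prompt, not the repository's actual API.

```python
# Minimal, self-contained sketch of the new prepare -> filter -> pad flow.
# "Item" and the literal paths/metadata below are illustrative stand-ins only.
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class Item:
    instance_image_path: Path
    class_image_path: Optional[Path]
    prompt: str
    nprompt: str


metadata = {
    "template": {"prompt": "a photo of {content}", "nprompt": "{content}"},
    "items": [
        {"image": "cat1.png", "prompt": "cat, sitting", "nprompt": "blurry"},
        {"image": "dog1.png", "prompt": "dog, running", "nprompt": "blurry"},
    ],
}

template = metadata["template"]

# prepare: build items from metadata; the class image path stays None for now
items = [
    Item(
        Path("data") / entry["image"],
        None,
        template["prompt"].format(content=entry["prompt"]),
        template["nprompt"].format(content=entry["nprompt"]),
    )
    for entry in metadata["items"]
]

# filter: keep only items whose prompt mentions one of the keywords
keyword_filter = ["cat"]
items = [i for i in items if any(k in i.prompt for k in keyword_filter)]

# pad: replicate items and assign class image paths so that at least
# num_class_images class entries exist
num_class_images = 4
multiplier = max(math.ceil(num_class_images / len(items)), 1)
padded = [
    Item(
        i.instance_image_path,
        Path("class") / f"{i.instance_image_path.stem}_{n}{i.instance_image_path.suffix}",
        i.prompt,
        i.nprompt,
    )
    for i in items
    for n in range(multiplier)
]

print(len(padded))  # -> 4: the single "cat" item is replicated four times
```

Note that in the diff itself, padding is applied after random_split, with num_class_images only for the training subset (self.pad_items(data_train, self.num_class_images)); the validation subset is padded with the default of 1, so it keeps one class-image slot per item.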
