diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py index 9125212..9c3c3f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
54 | dropout: float = 0, | 54 | dropout: float = 0, |
55 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
56 | center_crop: bool = False, | 56 | center_crop: bool = False, |
57 | template_key: str = "template", | ||
57 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
58 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
60 | keyword_filter: list[str] = [], | ||
59 | collate_fn=None, | 61 | collate_fn=None, |
60 | num_workers: int = 0 | 62 | num_workers: int = 0 |
61 | ): | 63 | ): |
@@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule): | |||
78 | self.repeats = repeats | 80 | self.repeats = repeats |
79 | self.dropout = dropout | 81 | self.dropout = dropout |
80 | self.center_crop = center_crop | 82 | self.center_crop = center_crop |
83 | self.template_key = template_key | ||
81 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
82 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
83 | self.generator = generator | 86 | self.generator = generator |
87 | self.keyword_filter = keyword_filter | ||
84 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
85 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
86 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
87 | 91 | ||
88 | def prepare_subdata(self, template, data, num_class_images=1): | 92 | def prepare_items(self, template, data) -> list[CSVDataItem]: |
89 | image = template["image"] if "image" in template else "{}" | 93 | image = template["image"] if "image" in template else "{}" |
90 | prompt = template["prompt"] if "prompt" in template else "{content}" | 94 | prompt = template["prompt"] if "prompt" in template else "{content}" |
91 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 95 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
92 | 96 | ||
93 | image_multiplier = max(math.ceil(num_class_images / len(data)), 1) | ||
94 | |||
95 | return [ | 97 | return [ |
96 | CSVDataItem( | 98 | CSVDataItem( |
97 | self.data_root.joinpath(image.format(item["image"])), | 99 | self.data_root.joinpath(image.format(item["image"])), |
98 | self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), | 100 | None, |
99 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 101 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
100 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) | 102 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) |
101 | ) | 103 | ) |
102 | for item in data | 104 | for item in data |
105 | ] | ||
106 | |||
107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | ||
108 | if len(self.keyword_filter) == 0: | ||
109 | return items | ||
110 | |||
111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | ||
112 | |||
113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | ||
114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | ||
115 | |||
116 | return [ | ||
117 | CSVDataItem( | ||
118 | item.instance_image_path, | ||
119 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | ||
120 | item.prompt, | ||
121 | item.nprompt | ||
122 | ) | ||
123 | for item in items | ||
103 | for i in range(image_multiplier) | 124 | for i in range(image_multiplier) |
104 | ] | 125 | ] |
105 | 126 | ||
106 | def prepare_data(self): | 127 | def prepare_data(self): |
107 | with open(self.data_file, 'rt') as f: | 128 | with open(self.data_file, 'rt') as f: |
108 | metadata = json.load(f) | 129 | metadata = json.load(f) |
109 | template = metadata["template"] if "template" in metadata else {} | 130 | template = metadata[self.template_key] if self.template_key in metadata else {} |
110 | items = metadata["items"] if "items" in metadata else [] | 131 | items = metadata["items"] if "items" in metadata else [] |
111 | 132 | ||
112 | items = [item for item in items if not "skip" in item or item["skip"] != True] | 133 | items = [item for item in items if not "skip" in item or item["skip"] != True] |
134 | items = self.prepare_items(template, items) | ||
135 | items = self.filter_items(items) | ||
136 | |||
113 | num_images = len(items) | 137 | num_images = len(items) |
114 | 138 | ||
115 | valid_set_size = int(num_images * 0.1) | 139 | valid_set_size = int(num_images * 0.1) |
@@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
120 | 144 | ||
121 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) | 145 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) |
122 | 146 | ||
123 | self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) | 147 | self.data_train = self.pad_items(data_train, self.num_class_images) |
124 | self.data_val = self.prepare_subdata(template, data_val) | 148 | self.data_val = self.pad_items(data_val) |
125 | 149 | ||
126 | def setup(self, stage=None): | 150 | def setup(self, stage=None): |
127 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 151 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |