Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  144
1 files changed, 88 insertions, 56 deletions
diff --git a/data/csv.py b/data/csv.py
index 4986153..59d6d8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
+image_cache: dict[str, Image.Image] = {}
+
+
+def get_image(path):
+    if path in image_cache:
+        return image_cache[path]
+
+    image = Image.open(path)
+    if not image.mode == "RGB":
+        image = image.convert("RGB")
+    image_cache[path] = image
+
+    return image
+
+
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-class CSVDataItem(NamedTuple):
+class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
@@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple):
     collection: list[str]
 
 
-class CSVDataModule():
+class VlpnDataBucket():
+    def __init__(self, width: int, height: int):
+        self.width = width
+        self.height = height
+        self.ratio = width / height
+        self.items: list[VlpnDataItem] = []
+
+
+class VlpnDataModule():
     def __init__(
         self,
         batch_size: int,
@@ -36,11 +59,10 @@ class CSVDataModule():
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         seed: Optional[int] = None,
-        filter: Optional[Callable[[CSVDataItem], bool]] = None,
+        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -60,7 +82,6 @@ class CSVDataModule():
         self.size = size
         self.repeats = repeats
         self.dropout = dropout
-        self.center_crop = center_crop
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -70,14 +91,14 @@ class CSVDataModule():
         self.num_workers = num_workers
         self.batch_size = batch_size
 
-    def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
+    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
         cprompt = template["cprompt"] if "cprompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
                 prompt_to_keywords(
@@ -97,17 +118,17 @@ class CSVDataModule():
             for item in data
         ]
 
-    def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
+    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
         if self.filter is None:
             return items
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
+    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
         image_multiplier = max(num_class_images, 1)
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
@@ -119,7 +140,30 @@ class CSVDataModule():
             for i in range(image_multiplier)
         ]
 
-    def prepare_data(self):
+    def generate_buckets(self, items: list[VlpnDataItem]):
+        buckets = [VlpnDataBucket(self.size, self.size)]
+
+        for i in range(1, 5):
+            s = self.size + i * 64
+            buckets.append(VlpnDataBucket(s, self.size))
+            buckets.append(VlpnDataBucket(self.size, s))
+
+        for item in items:
+            image = get_image(item.instance_image_path)
+            ratio = image.width / image.height
+
+            if ratio >= 1:
+                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+            else:
+                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+
+            for bucket in candidates:
+                bucket.items.append(item)
+
+        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
+        return buckets
+
+    def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
@@ -144,48 +188,48 @@ class CSVDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-    def setup(self, stage=None):
-        train_dataset = CSVDataset(
-            self.data_train, self.prompt_processor, batch_size=self.batch_size,
-            num_class_images=self.num_class_images,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
-        )
-        val_dataset = CSVDataset(
-            self.data_val, self.prompt_processor, batch_size=self.batch_size,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop
-        )
-        self.train_dataloader_ = DataLoader(
-            train_dataset, batch_size=self.batch_size,
-            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
-        )
-        self.val_dataloader_ = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
+        buckets = self.generate_buckets(data_train)
+
+        train_datasets = [
+            VlpnDataset(
+                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
+                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
+            )
+            for bucket in buckets
+        ]
+
+        val_dataset = VlpnDataset(
+            data_val, self.prompt_processor, batch_size=self.batch_size,
+            width=self.size, height=self.size, interpolation=self.interpolation,
         )
 
-    def train_dataloader(self):
-        return self.train_dataloader_
+        self.train_dataloaders = [
+            DataLoader(
+                dataset, batch_size=self.batch_size, shuffle=True,
+                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            )
+            for dataset in train_datasets
+        ]
 
-    def val_dataloader(self):
-        return self.val_dataloader_
+        self.val_dataloader = DataLoader(
+            val_dataset, batch_size=self.batch_size,
+            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
 
 
-class CSVDataset(Dataset):
+class VlpnDataset(Dataset):
     def __init__(
         self,
-        data: List[CSVDataItem],
+        data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
         batch_size: int = 1,
         num_class_images: int = 0,
-        size: int = 768,
+        width: int = 768,
+        height: int = 768,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
     ):
 
         self.data = data
@@ -193,7 +237,6 @@ class CSVDataset(Dataset):
         self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout
-        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -206,8 +249,8 @@ class CSVDataset(Dataset):
         }[interpolation]
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=self.interpolation),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(min(width, height), interpolation=self.interpolation),
+                transforms.RandomCrop((height, width)),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
@@ -217,17 +260,6 @@ class CSVDataset(Dataset):
     def __len__(self):
        return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_image(self, path):
-        if path in self.image_cache:
-            return self.image_cache[path]
-
-        image = Image.open(path)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-        self.image_cache[path] = image
-
-        return image
-
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
 
@@ -235,9 +267,9 @@ class CSVDataset(Dataset):
         example["prompts"] = item.prompt
         example["cprompts"] = item.cprompt
         example["nprompts"] = item.nprompt
-        example["instance_images"] = self.get_image(item.instance_image_path)
+        example["instance_images"] = get_image(item.instance_image_path)
         if self.num_class_images != 0:
-            example["class_images"] = self.get_image(item.class_image_path)
+            example["class_images"] = get_image(item.class_image_path)
 
         return example
 
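
The new generate_buckets() step is the core of this change: each training image is routed to the fixed-size buckets whose aspect ratio it can cover, and every non-empty bucket later gets its own VlpnDataset and DataLoader, so a batch never mixes resolutions. Below is a minimal, self-contained sketch of that assignment rule; the Bucket class, helper names, and example items are illustrative stand-ins, not code from this commit.

# Illustrative sketch of the bucket-assignment rule from generate_buckets();
# names and sample data here are assumptions for demonstration only.
class Bucket:
    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.ratio = width / height
        self.items = []


def make_buckets(size):
    # One square bucket plus landscape/portrait variants in 64px steps, as in the diff.
    buckets = [Bucket(size, size)]
    for i in range(1, 5):
        s = size + i * 64
        buckets.append(Bucket(s, size))   # landscape, ratio > 1
        buckets.append(Bucket(size, s))   # portrait, ratio < 1
    return buckets


def assign(images, buckets):
    # images: (name, width, height) tuples standing in for data items and their files.
    for name, width, height in images:
        ratio = width / height
        if ratio >= 1:
            candidates = [b for b in buckets if b.ratio >= 1 and ratio >= b.ratio]
        else:
            candidates = [b for b in buckets if b.ratio <= 1 and ratio <= b.ratio]
        for bucket in candidates:
            bucket.items.append(name)
    return [b for b in buckets if b.items]


used = assign([("wide.png", 1024, 576), ("square.png", 512, 512)], make_buckets(512))
for b in used:
    print(b.width, b.height, b.items)

Note that, as in the diff, a wide or tall image is appended to every bucket whose ratio it meets, so it can appear in more than one training dataloader; validation, by contrast, keeps a single square dataset at the configured size.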