diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 33 |
1 files changed, 22 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py index 59d6d8d..654aec1 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split | |||
7 | from torchvision import transforms | 7 | from torchvision import transforms |
8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable | 8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
9 | 9 | ||
10 | import numpy as np | ||
11 | |||
10 | from models.clip.prompt import PromptProcessor | 12 | from models.clip.prompt import PromptProcessor |
11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 13 | from data.keywords import prompt_to_keywords, keywords_to_prompt |
12 | 14 | ||
@@ -56,6 +58,8 @@ class VlpnDataModule(): | |||
56 | class_subdir: str = "cls", | 58 | class_subdir: str = "cls", |
57 | num_class_images: int = 1, | 59 | num_class_images: int = 1, |
58 | size: int = 768, | 60 | size: int = 768, |
61 | num_aspect_ratio_buckets: int = 0, | ||
62 | progressive_aspect_ratio_buckets: bool = False, | ||
59 | repeats: int = 1, | 63 | repeats: int = 1, |
60 | dropout: float = 0, | 64 | dropout: float = 0, |
61 | interpolation: str = "bicubic", | 65 | interpolation: str = "bicubic", |
@@ -80,6 +84,8 @@ class VlpnDataModule(): | |||
80 | 84 | ||
81 | self.prompt_processor = prompt_processor | 85 | self.prompt_processor = prompt_processor |
82 | self.size = size | 86 | self.size = size |
87 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | ||
88 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | ||
83 | self.repeats = repeats | 89 | self.repeats = repeats |
84 | self.dropout = dropout | 90 | self.dropout = dropout |
85 | self.template_key = template_key | 91 | self.template_key = template_key |
@@ -143,25 +149,32 @@ class VlpnDataModule(): | |||
143 | def generate_buckets(self, items: list[VlpnDataItem]): | 149 | def generate_buckets(self, items: list[VlpnDataItem]): |
144 | buckets = [VlpnDataBucket(self.size, self.size)] | 150 | buckets = [VlpnDataBucket(self.size, self.size)] |
145 | 151 | ||
146 | for i in range(1, 5): | 152 | for i in range(1, self.num_aspect_ratio_buckets + 1): |
147 | s = self.size + i * 64 | 153 | s = self.size + i * 64 |
148 | buckets.append(VlpnDataBucket(s, self.size)) | 154 | buckets.append(VlpnDataBucket(s, self.size)) |
149 | buckets.append(VlpnDataBucket(self.size, s)) | 155 | buckets.append(VlpnDataBucket(self.size, s)) |
150 | 156 | ||
157 | buckets = np.array(buckets) | ||
158 | bucket_ratios = np.array([bucket.ratio for bucket in buckets]) | ||
159 | |||
151 | for item in items: | 160 | for item in items: |
152 | image = get_image(item.instance_image_path) | 161 | image = get_image(item.instance_image_path) |
153 | ratio = image.width / image.height | 162 | ratio = image.width / image.height |
154 | 163 | ||
155 | if ratio >= 1: | 164 | if ratio >= 1: |
156 | candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] | 165 | mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio) |
157 | else: | 166 | else: |
158 | candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] | 167 | mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio) |
159 | 168 | ||
160 | for bucket in candidates: | 169 | if not self.progressive_aspect_ratio_buckets: |
170 | ratios = bucket_ratios.copy() | ||
171 | ratios[~mask] = math.inf | ||
172 | mask = [np.argmin(np.abs(ratios - ratio))] | ||
173 | |||
174 | for bucket in buckets[mask]: | ||
161 | bucket.items.append(item) | 175 | bucket.items.append(item) |
162 | 176 | ||
163 | buckets = [bucket for bucket in buckets if len(bucket.items) != 0] | 177 | return [bucket for bucket in buckets if len(bucket.items) != 0] |
164 | return buckets | ||
165 | 178 | ||
166 | def setup(self): | 179 | def setup(self): |
167 | with open(self.data_file, 'rt') as f: | 180 | with open(self.data_file, 'rt') as f: |
@@ -192,7 +205,7 @@ class VlpnDataModule(): | |||
192 | 205 | ||
193 | train_datasets = [ | 206 | train_datasets = [ |
194 | VlpnDataset( | 207 | VlpnDataset( |
195 | bucket.items, self.prompt_processor, batch_size=self.batch_size, | 208 | bucket.items, self.prompt_processor, |
196 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, | 209 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, |
197 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, | 210 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, |
198 | ) | 211 | ) |
@@ -200,7 +213,7 @@ class VlpnDataModule(): | |||
200 | ] | 213 | ] |
201 | 214 | ||
202 | val_dataset = VlpnDataset( | 215 | val_dataset = VlpnDataset( |
203 | data_val, self.prompt_processor, batch_size=self.batch_size, | 216 | data_val, self.prompt_processor, |
204 | width=self.size, height=self.size, interpolation=self.interpolation, | 217 | width=self.size, height=self.size, interpolation=self.interpolation, |
205 | ) | 218 | ) |
206 | 219 | ||
@@ -223,7 +236,6 @@ class VlpnDataset(Dataset): | |||
223 | self, | 236 | self, |
224 | data: List[VlpnDataItem], | 237 | data: List[VlpnDataItem], |
225 | prompt_processor: PromptProcessor, | 238 | prompt_processor: PromptProcessor, |
226 | batch_size: int = 1, | ||
227 | num_class_images: int = 0, | 239 | num_class_images: int = 0, |
228 | width: int = 768, | 240 | width: int = 768, |
229 | height: int = 768, | 241 | height: int = 768, |
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset): | |||
234 | 246 | ||
235 | self.data = data | 247 | self.data = data |
236 | self.prompt_processor = prompt_processor | 248 | self.prompt_processor = prompt_processor |
237 | self.batch_size = batch_size | ||
238 | self.num_class_images = num_class_images | 249 | self.num_class_images = num_class_images |
239 | self.dropout = dropout | 250 | self.dropout = dropout |
240 | 251 | ||
@@ -258,7 +269,7 @@ class VlpnDataset(Dataset): | |||
258 | ) | 269 | ) |
259 | 270 | ||
260 | def __len__(self): | 271 | def __len__(self): |
261 | return math.ceil(self._length / self.batch_size) * self.batch_size | 272 | return self._length |
262 | 273 | ||
263 | def get_example(self, i): | 274 | def get_example(self, i): |
264 | item = self.data[i % self.num_instance_images] | 275 | item = self.data[i % self.num_instance_images] |