summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 33
1 files changed, 22 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py
index 59d6d8d..654aec1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split
7from torchvision import transforms 7from torchvision import transforms
8from typing import Dict, NamedTuple, List, Optional, Union, Callable 8from typing import Dict, NamedTuple, List, Optional, Union, Callable
9 9
10import numpy as np
11
10from models.clip.prompt import PromptProcessor 12from models.clip.prompt import PromptProcessor
11from data.keywords import prompt_to_keywords, keywords_to_prompt 13from data.keywords import prompt_to_keywords, keywords_to_prompt
12 14
@@ -56,6 +58,8 @@ class VlpnDataModule():
56 class_subdir: str = "cls", 58 class_subdir: str = "cls",
57 num_class_images: int = 1, 59 num_class_images: int = 1,
58 size: int = 768, 60 size: int = 768,
61 num_aspect_ratio_buckets: int = 0,
62 progressive_aspect_ratio_buckets: bool = False,
59 repeats: int = 1, 63 repeats: int = 1,
60 dropout: float = 0, 64 dropout: float = 0,
61 interpolation: str = "bicubic", 65 interpolation: str = "bicubic",
@@ -80,6 +84,8 @@ class VlpnDataModule():
80 84
81 self.prompt_processor = prompt_processor 85 self.prompt_processor = prompt_processor
82 self.size = size 86 self.size = size
87 self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
88 self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
83 self.repeats = repeats 89 self.repeats = repeats
84 self.dropout = dropout 90 self.dropout = dropout
85 self.template_key = template_key 91 self.template_key = template_key
@@ -143,25 +149,32 @@ class VlpnDataModule():
143 def generate_buckets(self, items: list[VlpnDataItem]): 149 def generate_buckets(self, items: list[VlpnDataItem]):
144 buckets = [VlpnDataBucket(self.size, self.size)] 150 buckets = [VlpnDataBucket(self.size, self.size)]
145 151
146 for i in range(1, 5): 152 for i in range(1, self.num_aspect_ratio_buckets + 1):
147 s = self.size + i * 64 153 s = self.size + i * 64
148 buckets.append(VlpnDataBucket(s, self.size)) 154 buckets.append(VlpnDataBucket(s, self.size))
149 buckets.append(VlpnDataBucket(self.size, s)) 155 buckets.append(VlpnDataBucket(self.size, s))
150 156
157 buckets = np.array(buckets)
158 bucket_ratios = np.array([bucket.ratio for bucket in buckets])
159
151 for item in items: 160 for item in items:
152 image = get_image(item.instance_image_path) 161 image = get_image(item.instance_image_path)
153 ratio = image.width / image.height 162 ratio = image.width / image.height
154 163
155 if ratio >= 1: 164 if ratio >= 1:
156 candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] 165 mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
157 else: 166 else:
158 candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] 167 mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
159 168
160 for bucket in candidates: 169 if not self.progressive_aspect_ratio_buckets:
170 ratios = bucket_ratios.copy()
171 ratios[~mask] = math.inf
172 mask = [np.argmin(np.abs(ratios - ratio))]
173
174 for bucket in buckets[mask]:
161 bucket.items.append(item) 175 bucket.items.append(item)
162 176
163 buckets = [bucket for bucket in buckets if len(bucket.items) != 0] 177 return [bucket for bucket in buckets if len(bucket.items) != 0]
164 return buckets
165 178
166 def setup(self): 179 def setup(self):
167 with open(self.data_file, 'rt') as f: 180 with open(self.data_file, 'rt') as f:
@@ -192,7 +205,7 @@ class VlpnDataModule():
192 205
193 train_datasets = [ 206 train_datasets = [
194 VlpnDataset( 207 VlpnDataset(
195 bucket.items, self.prompt_processor, batch_size=self.batch_size, 208 bucket.items, self.prompt_processor,
196 width=bucket.width, height=bucket.height, interpolation=self.interpolation, 209 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
197 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, 210 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
198 ) 211 )
@@ -200,7 +213,7 @@ class VlpnDataModule():
200 ] 213 ]
201 214
202 val_dataset = VlpnDataset( 215 val_dataset = VlpnDataset(
203 data_val, self.prompt_processor, batch_size=self.batch_size, 216 data_val, self.prompt_processor,
204 width=self.size, height=self.size, interpolation=self.interpolation, 217 width=self.size, height=self.size, interpolation=self.interpolation,
205 ) 218 )
206 219
@@ -223,7 +236,6 @@ class VlpnDataset(Dataset):
223 self, 236 self,
224 data: List[VlpnDataItem], 237 data: List[VlpnDataItem],
225 prompt_processor: PromptProcessor, 238 prompt_processor: PromptProcessor,
226 batch_size: int = 1,
227 num_class_images: int = 0, 239 num_class_images: int = 0,
228 width: int = 768, 240 width: int = 768,
229 height: int = 768, 241 height: int = 768,
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset):
234 246
235 self.data = data 247 self.data = data
236 self.prompt_processor = prompt_processor 248 self.prompt_processor = prompt_processor
237 self.batch_size = batch_size
238 self.num_class_images = num_class_images 249 self.num_class_images = num_class_images
239 self.dropout = dropout 250 self.dropout = dropout
240 251
@@ -258,7 +269,7 @@ class VlpnDataset(Dataset):
258 ) 269 )
259 270
260 def __len__(self): 271 def __len__(self):
261 return math.ceil(self._length / self.batch_size) * self.batch_size 272 return self._length
262 273
263 def get_example(self, i): 274 def get_example(self, i):
264 item = self.data[i % self.num_instance_images] 275 item = self.data[i % self.num_instance_images]