summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index dec66d7..85b98f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -174,7 +174,8 @@ class VlpnDataModule():
174 interpolation: str = "bicubic", 174 interpolation: str = "bicubic",
175 template_key: str = "template", 175 template_key: str = "template",
176 valid_set_size: Optional[int] = None, 176 valid_set_size: Optional[int] = None,
177 valid_set_repeat: int = 1, 177 train_set_pad: Optional[int] = None,
178 valid_set_pad: Optional[int] = None,
178 seed: Optional[int] = None, 179 seed: Optional[int] = None,
179 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 180 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
180 dtype: torch.dtype = torch.float32, 181 dtype: torch.dtype = torch.float32,
@@ -202,7 +203,8 @@ class VlpnDataModule():
202 self.template_key = template_key 203 self.template_key = template_key
203 self.interpolation = interpolation 204 self.interpolation = interpolation
204 self.valid_set_size = valid_set_size 205 self.valid_set_size = valid_set_size
205 self.valid_set_repeat = valid_set_repeat 206 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
207 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
206 self.seed = seed 208 self.seed = seed
207 self.filter = filter 209 self.filter = filter
208 self.batch_size = batch_size 210 self.batch_size = batch_size
@@ -267,9 +269,6 @@ class VlpnDataModule():
267 items = self.prepare_items(template, expansions, items) 269 items = self.prepare_items(template, expansions, items)
268 items = self.filter_items(items) 270 items = self.filter_items(items)
269 271
270 if (len(items) < self.batch_size):
271 items = (items * self.batch_size)[:self.batch_size]
272
273 num_images = len(items) 272 num_images = len(items)
274 273
275 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 274 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
@@ -283,14 +282,17 @@ class VlpnDataModule():
283 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
284 283
285 if valid_set_size == 0: 284 if valid_set_size == 0:
286 data_train, data_val = items, [] 285 data_train, data_val = items, items[:1]
287 else: 286 else:
288 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
289 288
290 self.data_train = self.pad_items(data_train, self.num_class_images) 289 data_train = self.pad_items(data_train, self.num_class_images)
290
291 if len(data_train) < self.train_set_pad:
292 data_train *= math.ceil(self.train_set_pad / len(data_train))
291 293
292 train_dataset = VlpnDataset( 294 self.train_dataset = VlpnDataset(
293 self.data_train, self.tokenizer, 295 data_train, self.tokenizer,
294 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 296 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
295 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 297 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
296 batch_size=self.batch_size, generator=generator, 298 batch_size=self.batch_size, generator=generator,
@@ -299,24 +301,26 @@ class VlpnDataModule():
299 ) 301 )
300 302
301 self.train_dataloader = DataLoader( 303 self.train_dataloader = DataLoader(
302 train_dataset, 304 self.train_dataset,
303 batch_size=None, pin_memory=True, collate_fn=collate_fn_ 305 batch_size=None, pin_memory=True, collate_fn=collate_fn_
304 ) 306 )
305 307
306 if valid_set_size != 0: 308 if len(data_val) != 0:
307 self.data_val = self.pad_items(data_val) 309 data_val = self.pad_items(data_val)
310
311 if len(data_val) < self.valid_set_pad:
312 data_val *= math.ceil(self.valid_set_pad / len(data_val))
308 313
309 val_dataset = VlpnDataset( 314 self.val_dataset = VlpnDataset(
310 self.data_val, self.tokenizer, 315 data_val, self.tokenizer,
311 num_buckets=self.num_buckets, progressive_buckets=True, 316 num_buckets=self.num_buckets, progressive_buckets=True,
312 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 317 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
313 repeat=self.valid_set_repeat,
314 batch_size=self.batch_size, generator=generator, 318 batch_size=self.batch_size, generator=generator,
315 size=self.size, interpolation=self.interpolation, 319 size=self.size, interpolation=self.interpolation,
316 ) 320 )
317 321
318 self.val_dataloader = DataLoader( 322 self.val_dataloader = DataLoader(
319 val_dataset, 323 self.val_dataset,
320 batch_size=None, pin_memory=True, collate_fn=collate_fn_ 324 batch_size=None, pin_memory=True, collate_fn=collate_fn_
321 ) 325 )
322 else: 326 else:
@@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset):
332 bucket_step_size: int = 64, 336 bucket_step_size: int = 64,
333 bucket_max_pixels: Optional[int] = None, 337 bucket_max_pixels: Optional[int] = None,
334 progressive_buckets: bool = False, 338 progressive_buckets: bool = False,
335 repeat: int = 1,
336 batch_size: int = 1, 339 batch_size: int = 1,
337 num_class_images: int = 0, 340 num_class_images: int = 0,
338 size: int = 768, 341 size: int = 768,
@@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset):
341 interpolation: str = "bicubic", 344 interpolation: str = "bicubic",
342 generator: Optional[torch.Generator] = None, 345 generator: Optional[torch.Generator] = None,
343 ): 346 ):
344 self.items = items * repeat 347 self.items = items
345 self.batch_size = batch_size 348 self.batch_size = batch_size
346 349
347 self.tokenizer = tokenizer 350 self.tokenizer = tokenizer