author    Volpeon <git@volpeon.ink>  2023-01-16 10:31:55 +0100
committer Volpeon <git@volpeon.ink>  2023-01-16 10:31:55 +0100
commit    89afcfda3f824cc44221e877182348f9b09687d2 (patch)
tree      804b84322e5caa8fb861322ce6970bef4b532c61 /data
parent    Extended Dreambooth: Train TI tokens separately (diff)
Handle empty validation dataset
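
The old code forced valid_set_size to at least 1, so an empty validation set
was impossible, and a requested size larger than the dataset went unchecked.
The sizes are now clamped: the validation set is capped at the dataset size,
the train set keeps at least one item, and when the validation set comes out
empty the random_split is skipped and no validation dataset or dataloader is
built (val_dataloader is set to None instead).

A minimal sketch of the new size arithmetic, standalone for illustration
(split_sizes is a hypothetical helper, not part of the module):

    def split_sizes(num_images, valid_set_size=None):
        # Mirror the new logic: cap the validation set at the dataset size
        # and always keep at least one training item.
        valid = min(valid_set_size, num_images) if valid_set_size is not None else num_images // 10
        train = max(num_images - valid, 1)
        return train, num_images - train

    print(split_sizes(1))      # (1, 0)   single image: validation set is empty
    print(split_sizes(5, 10))  # (1, 4)   oversized request gets clamped
    print(split_sizes(100))    # (90, 10) default: 10% held out for validation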
Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  47
1 file changed, 27 insertions(+), 20 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 002fdd2..968af8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -269,18 +269,22 @@ class VlpnDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10
-        valid_set_size = max(valid_set_size, 1)
-        train_set_size = num_images - valid_set_size
+        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
+        train_set_size = max(num_images - valid_set_size, 1)
+        valid_set_size = num_images - train_set_size
 
         generator = torch.Generator(device="cpu")
         if self.seed is not None:
             generator = generator.manual_seed(self.seed)
 
-        data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
+
+        if valid_set_size == 0:
+            data_train, data_val = items, []
+        else:
+            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
 
         self.data_train = self.pad_items(data_train, self.num_class_images)
-        self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
             self.data_train, self.tokenizer,
@@ -291,26 +295,29 @@ class VlpnDataModule():
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )
 
-        val_dataset = VlpnDataset(
-            self.data_val, self.tokenizer,
-            num_buckets=self.num_buckets, progressive_buckets=True,
-            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            repeat=self.valid_set_repeat,
-            batch_size=self.batch_size, generator=generator,
-            size=self.size, interpolation=self.interpolation,
-        )
-
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
-
         self.train_dataloader = DataLoader(
             train_dataset,
             batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
-        self.val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_
-        )
+        if valid_set_size != 0:
+            self.data_val = self.pad_items(data_val)
+
+            val_dataset = VlpnDataset(
+                self.data_val, self.tokenizer,
+                num_buckets=self.num_buckets, progressive_buckets=True,
+                bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
+                repeat=self.valid_set_repeat,
+                batch_size=self.batch_size, generator=generator,
+                size=self.size, interpolation=self.interpolation,
+            )
+
+            self.val_dataloader = DataLoader(
+                val_dataset,
+                batch_size=None, pin_memory=True, collate_fn=collate_fn_
+            )
+        else:
+            self.val_dataloader = None
 
 
 class VlpnDataset(IterableDataset):
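
For reference, a self-contained sketch of the full split-and-guard flow this
commit introduces, under the assumption that items is a plain list of dataset
entries; split_items and the trailing demo are illustrative stand-ins, not the
exact module code:

    import torch
    from torch.utils.data import random_split

    def split_items(items, valid_set_size=None, seed=None):
        num_images = len(items)

        # Clamp the requested validation size to the dataset and keep at
        # least one training item; valid_set_size may legitimately be 0.
        valid_set_size = min(valid_set_size, num_images) if valid_set_size is not None else num_images // 10
        train_set_size = max(num_images - valid_set_size, 1)
        valid_set_size = num_images - train_set_size

        generator = torch.Generator(device="cpu")
        if seed is not None:
            generator = generator.manual_seed(seed)

        # An empty validation set is an expected case now: return it
        # explicitly so the caller can skip building the validation
        # dataset and dataloader (leaving val_dataloader as None).
        if valid_set_size == 0:
            return items, []
        return random_split(items, [train_set_size, valid_set_size], generator=generator)

    data_train, data_val = split_items(list(range(3)), valid_set_size=0, seed=0)
    assert data_val == []  # caller would set val_dataloader = None here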