From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 15:52:43 +0100 Subject: More training adjustments --- data/csv.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index dec66d7..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -174,7 +174,8 @@ class VlpnDataModule(): interpolation: str = "bicubic", template_key: str = "template", valid_set_size: Optional[int] = None, - valid_set_repeat: int = 1, + train_set_pad: Optional[int] = None, + valid_set_pad: Optional[int] = None, seed: Optional[int] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, @@ -202,7 +203,8 @@ class VlpnDataModule(): self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size - self.valid_set_repeat = valid_set_repeat + self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size + self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size self.seed = seed self.filter = filter self.batch_size = batch_size @@ -267,9 +269,6 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) - if (len(items) < self.batch_size): - items = (items * self.batch_size)[:self.batch_size] - num_images = len(items) valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 @@ -283,14 +282,17 @@ class VlpnDataModule(): collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) if valid_set_size == 0: - data_train, data_val = items, [] + data_train, data_val = items, items[:1] else: data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) - self.data_train = self.pad_items(data_train, self.num_class_images) + data_train = self.pad_items(data_train, self.num_class_images) + + if len(data_train) < self.train_set_pad: + data_train *= math.ceil(self.train_set_pad / len(data_train)) - train_dataset = VlpnDataset( - self.data_train, self.tokenizer, + self.train_dataset = VlpnDataset( + data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, @@ -299,24 +301,26 @@ class VlpnDataModule(): ) self.train_dataloader = DataLoader( - train_dataset, + self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) - if valid_set_size != 0: - self.data_val = self.pad_items(data_val) + if len(data_val) != 0: + data_val = self.pad_items(data_val) + + if len(data_val) < self.valid_set_pad: + data_val *= math.ceil(self.valid_set_pad / len(data_val)) - val_dataset = VlpnDataset( - self.data_val, self.tokenizer, + self.val_dataset = VlpnDataset( + data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - repeat=self.valid_set_repeat, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, ) self.val_dataloader = DataLoader( - val_dataset, + self.val_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) else: @@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset): bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, - repeat: int = 1, batch_size: int = 1, num_class_images: int = 0, size: int = 768, @@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset): interpolation: str = "bicubic", generator: Optional[torch.Generator] = None, ): - self.items = items * repeat + self.items = items self.batch_size = batch_size self.tokenizer = tokenizer -- cgit v1.2.3-70-g09d2