From 7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 9 Jan 2023 10:57:05 +0100 Subject: Enable buckets for validation, fixed vaildation repeat arg --- data/csv.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -245,6 +245,8 @@ class VlpnDataModule(): val_dataset = VlpnDataset( self.data_val, self.prompt_processor, + num_buckets=self.num_buckets, progressive_buckets=True, + bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, repeat=self.valid_set_repeat, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, @@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): self.generator = generator self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( - [item.instance_image_path for item in items], + [item.instance_image_path for item in self.items], base_size=size, step_size=bucket_step_size, num_buckets=num_buckets, @@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): self.bucket_item_range = torch.arange(len(self.bucket_items)) - self.cache = {} self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() def __len__(self): -- cgit v1.2.3-70-g09d2