diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -245,6 +245,8 @@ class VlpnDataModule(): | |||
245 | 245 | ||
246 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
247 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
248 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
249 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
248 | repeat=self.valid_set_repeat, | 250 | repeat=self.valid_set_repeat, |
249 | batch_size=self.batch_size, generator=generator, | 251 | batch_size=self.batch_size, generator=generator, |
250 | size=self.size, interpolation=self.interpolation, | 252 | size=self.size, interpolation=self.interpolation, |
@@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): | |||
291 | self.generator = generator | 293 | self.generator = generator |
292 | 294 | ||
293 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 295 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
294 | [item.instance_image_path for item in items], | 296 | [item.instance_image_path for item in self.items], |
295 | base_size=size, | 297 | base_size=size, |
296 | step_size=bucket_step_size, | 298 | step_size=bucket_step_size, |
297 | num_buckets=num_buckets, | 299 | num_buckets=num_buckets, |
@@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): | |||
301 | 303 | ||
302 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 304 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
303 | 305 | ||
304 | self.cache = {} | ||
305 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 306 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
306 | 307 | ||
307 | def __len__(self): | 308 | def __len__(self): |