diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -245,6 +245,8 @@ class VlpnDataModule(): | |||
| 245 | 245 | ||
| 246 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
| 247 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
| 248 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 249 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 248 | repeat=self.valid_set_repeat, | 250 | repeat=self.valid_set_repeat, |
| 249 | batch_size=self.batch_size, generator=generator, | 251 | batch_size=self.batch_size, generator=generator, |
| 250 | size=self.size, interpolation=self.interpolation, | 252 | size=self.size, interpolation=self.interpolation, |
| @@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): | |||
| 291 | self.generator = generator | 293 | self.generator = generator |
| 292 | 294 | ||
| 293 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 295 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
| 294 | [item.instance_image_path for item in items], | 296 | [item.instance_image_path for item in self.items], |
| 295 | base_size=size, | 297 | base_size=size, |
| 296 | step_size=bucket_step_size, | 298 | step_size=bucket_step_size, |
| 297 | num_buckets=num_buckets, | 299 | num_buckets=num_buckets, |
| @@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): | |||
| 301 | 303 | ||
| 302 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 304 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
| 303 | 305 | ||
| 304 | self.cache = {} | ||
| 305 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 306 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
| 306 | 307 | ||
| 307 | def __len__(self): | 308 | def __len__(self): |
