diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 14 |
1 file changed, 5 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py index 818fcd9..3af9925 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -192,7 +192,7 @@ class VlpnDataModule(): | |||
192 | valid_set_size: Optional[int] = None, | 192 | valid_set_size: Optional[int] = None, |
193 | train_set_pad: Optional[int] = None, | 193 | train_set_pad: Optional[int] = None, |
194 | valid_set_pad: Optional[int] = None, | 194 | valid_set_pad: Optional[int] = None, |
195 | seed: Optional[int] = None, | 195 | generator: Optional[torch.Generator] = None, |
196 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 196 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
197 | dtype: torch.dtype = torch.float32, | 197 | dtype: torch.dtype = torch.float32, |
198 | ): | 198 | ): |
@@ -224,10 +224,10 @@ class VlpnDataModule(): | |||
224 | self.valid_set_size = valid_set_size | 224 | self.valid_set_size = valid_set_size |
225 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size | 225 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size |
226 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size | 226 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size |
227 | self.seed = seed | ||
228 | self.filter = filter | 227 | self.filter = filter |
229 | self.batch_size = batch_size | 228 | self.batch_size = batch_size |
230 | self.dtype = dtype | 229 | self.dtype = dtype |
230 | self.generator = generator | ||
231 | 231 | ||
232 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 232 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
233 | tpl_image = template["image"] if "image" in template else "{}" | 233 | tpl_image = template["image"] if "image" in template else "{}" |
@@ -304,16 +304,12 @@ class VlpnDataModule(): | |||
304 | train_set_size = max(num_images - valid_set_size, 1) | 304 | train_set_size = max(num_images - valid_set_size, 1) |
305 | valid_set_size = num_images - train_set_size | 305 | valid_set_size = num_images - train_set_size |
306 | 306 | ||
307 | generator = torch.Generator(device="cpu") | ||
308 | if self.seed is not None: | ||
309 | generator = generator.manual_seed(self.seed) | ||
310 | |||
311 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) | 307 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) |
312 | 308 | ||
313 | if valid_set_size == 0: | 309 | if valid_set_size == 0: |
314 | data_train, data_val = items, items | 310 | data_train, data_val = items, items |
315 | else: | 311 | else: |
316 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 312 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) |
317 | 313 | ||
318 | data_train = self.pad_items(data_train, self.num_class_images) | 314 | data_train = self.pad_items(data_train, self.num_class_images) |
319 | 315 | ||
@@ -324,7 +320,7 @@ class VlpnDataModule(): | |||
324 | data_train, self.tokenizer, | 320 | data_train, self.tokenizer, |
325 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 321 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
326 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 322 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
327 | batch_size=self.batch_size, fill_batch=True, generator=generator, | 323 | batch_size=self.batch_size, fill_batch=True, generator=self.generator, |
328 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 324 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, |
329 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 325 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
330 | ) | 326 | ) |
@@ -344,7 +340,7 @@ class VlpnDataModule(): | |||
344 | data_val, self.tokenizer, | 340 | data_val, self.tokenizer, |
345 | num_buckets=self.num_buckets, progressive_buckets=True, | 341 | num_buckets=self.num_buckets, progressive_buckets=True, |
346 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 342 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
347 | batch_size=self.batch_size, generator=generator, | 343 | batch_size=self.batch_size, generator=self.generator, |
348 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 344 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, |
349 | ) | 345 | ) |
350 | 346 | ||