summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py14
1 files changed, 5 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py
index 818fcd9..3af9925 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -192,7 +192,7 @@ class VlpnDataModule():
192 valid_set_size: Optional[int] = None, 192 valid_set_size: Optional[int] = None,
193 train_set_pad: Optional[int] = None, 193 train_set_pad: Optional[int] = None,
194 valid_set_pad: Optional[int] = None, 194 valid_set_pad: Optional[int] = None,
195 seed: Optional[int] = None, 195 generator: Optional[torch.Generator] = None,
196 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 196 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
197 dtype: torch.dtype = torch.float32, 197 dtype: torch.dtype = torch.float32,
198 ): 198 ):
@@ -224,10 +224,10 @@ class VlpnDataModule():
224 self.valid_set_size = valid_set_size 224 self.valid_set_size = valid_set_size
225 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size 225 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
226 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size 226 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
227 self.seed = seed
228 self.filter = filter 227 self.filter = filter
229 self.batch_size = batch_size 228 self.batch_size = batch_size
230 self.dtype = dtype 229 self.dtype = dtype
230 self.generator = generator
231 231
232 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 232 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
233 tpl_image = template["image"] if "image" in template else "{}" 233 tpl_image = template["image"] if "image" in template else "{}"
@@ -304,16 +304,12 @@ class VlpnDataModule():
304 train_set_size = max(num_images - valid_set_size, 1) 304 train_set_size = max(num_images - valid_set_size, 1)
305 valid_set_size = num_images - train_set_size 305 valid_set_size = num_images - train_set_size
306 306
307 generator = torch.Generator(device="cpu")
308 if self.seed is not None:
309 generator = generator.manual_seed(self.seed)
310
311 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) 307 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
312 308
313 if valid_set_size == 0: 309 if valid_set_size == 0:
314 data_train, data_val = items, items 310 data_train, data_val = items, items
315 else: 311 else:
316 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 312 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
317 313
318 data_train = self.pad_items(data_train, self.num_class_images) 314 data_train = self.pad_items(data_train, self.num_class_images)
319 315
@@ -324,7 +320,7 @@ class VlpnDataModule():
324 data_train, self.tokenizer, 320 data_train, self.tokenizer,
325 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 321 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
326 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 322 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
327 batch_size=self.batch_size, fill_batch=True, generator=generator, 323 batch_size=self.batch_size, fill_batch=True, generator=self.generator,
328 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 324 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
329 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, 325 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
330 ) 326 )
@@ -344,7 +340,7 @@ class VlpnDataModule():
344 data_val, self.tokenizer, 340 data_val, self.tokenizer,
345 num_buckets=self.num_buckets, progressive_buckets=True, 341 num_buckets=self.num_buckets, progressive_buckets=True,
346 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 342 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
347 batch_size=self.batch_size, generator=generator, 343 batch_size=self.batch_size, generator=self.generator,
348 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 344 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
349 ) 345 )
350 346