summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py14
-rw-r--r--train_lora.py4
-rw-r--r--train_ti.py6
3 files changed, 12 insertions, 12 deletions
diff --git a/data/csv.py b/data/csv.py
index 818fcd9..3af9925 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -192,7 +192,7 @@ class VlpnDataModule():
192 valid_set_size: Optional[int] = None, 192 valid_set_size: Optional[int] = None,
193 train_set_pad: Optional[int] = None, 193 train_set_pad: Optional[int] = None,
194 valid_set_pad: Optional[int] = None, 194 valid_set_pad: Optional[int] = None,
195 seed: Optional[int] = None, 195 generator: Optional[torch.Generator] = None,
196 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 196 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
197 dtype: torch.dtype = torch.float32, 197 dtype: torch.dtype = torch.float32,
198 ): 198 ):
@@ -224,10 +224,10 @@ class VlpnDataModule():
224 self.valid_set_size = valid_set_size 224 self.valid_set_size = valid_set_size
225 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size 225 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
226 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size 226 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
227 self.seed = seed
228 self.filter = filter 227 self.filter = filter
229 self.batch_size = batch_size 228 self.batch_size = batch_size
230 self.dtype = dtype 229 self.dtype = dtype
230 self.generator = generator
231 231
232 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 232 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
233 tpl_image = template["image"] if "image" in template else "{}" 233 tpl_image = template["image"] if "image" in template else "{}"
@@ -304,16 +304,12 @@ class VlpnDataModule():
304 train_set_size = max(num_images - valid_set_size, 1) 304 train_set_size = max(num_images - valid_set_size, 1)
305 valid_set_size = num_images - train_set_size 305 valid_set_size = num_images - train_set_size
306 306
307 generator = torch.Generator(device="cpu")
308 if self.seed is not None:
309 generator = generator.manual_seed(self.seed)
310
311 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) 307 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
312 308
313 if valid_set_size == 0: 309 if valid_set_size == 0:
314 data_train, data_val = items, items 310 data_train, data_val = items, items
315 else: 311 else:
316 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 312 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
317 313
318 data_train = self.pad_items(data_train, self.num_class_images) 314 data_train = self.pad_items(data_train, self.num_class_images)
319 315
@@ -324,7 +320,7 @@ class VlpnDataModule():
324 data_train, self.tokenizer, 320 data_train, self.tokenizer,
325 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 321 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
326 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 322 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
327 batch_size=self.batch_size, fill_batch=True, generator=generator, 323 batch_size=self.batch_size, fill_batch=True, generator=self.generator,
328 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 324 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
329 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, 325 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
330 ) 326 )
@@ -344,7 +340,7 @@ class VlpnDataModule():
344 data_val, self.tokenizer, 340 data_val, self.tokenizer,
345 num_buckets=self.num_buckets, progressive_buckets=True, 341 num_buckets=self.num_buckets, progressive_buckets=True,
346 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 342 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
347 batch_size=self.batch_size, generator=generator, 343 batch_size=self.batch_size, generator=self.generator,
348 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 344 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
349 ) 345 )
350 346
diff --git a/train_lora.py b/train_lora.py
index 4bbc64e..0d8ee23 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -865,6 +865,8 @@ def main():
865 max_grad_norm=args.max_grad_norm, 865 max_grad_norm=args.max_grad_norm,
866 ) 866 )
867 867
868 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
869
868 create_datamodule = partial( 870 create_datamodule = partial(
869 VlpnDataModule, 871 VlpnDataModule,
870 data_file=args.train_data_file, 872 data_file=args.train_data_file,
@@ -882,8 +884,8 @@ def main():
882 valid_set_size=args.valid_set_size, 884 valid_set_size=args.valid_set_size,
883 train_set_pad=args.train_set_pad, 885 train_set_pad=args.train_set_pad,
884 valid_set_pad=args.valid_set_pad, 886 valid_set_pad=args.valid_set_pad,
885 seed=args.seed,
886 dtype=weight_dtype, 887 dtype=weight_dtype,
888 generator=data_generator,
887 ) 889 )
888 890
889 create_lr_scheduler = partial( 891 create_lr_scheduler = partial(
diff --git a/train_ti.py b/train_ti.py
index eb08bda..009495b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -817,6 +817,8 @@ def main():
817 sample_image_size=args.sample_image_size, 817 sample_image_size=args.sample_image_size,
818 ) 818 )
819 819
820 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
821
820 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): 822 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
821 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 823 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
822 tokenizer=tokenizer, 824 tokenizer=tokenizer,
@@ -855,9 +857,9 @@ def main():
855 valid_set_size=args.valid_set_size, 857 valid_set_size=args.valid_set_size,
856 train_set_pad=args.train_set_pad, 858 train_set_pad=args.train_set_pad,
857 valid_set_pad=args.valid_set_pad, 859 valid_set_pad=args.valid_set_pad,
858 seed=args.seed,
859 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 860 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
860 dtype=weight_dtype 861 dtype=weight_dtype,
862 generator=data_generator,
861 ) 863 )
862 datamodule.setup() 864 datamodule.setup()
863 865