diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 4bbc64e..0d8ee23 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -865,6 +865,8 @@ def main(): | |||
865 | max_grad_norm=args.max_grad_norm, | 865 | max_grad_norm=args.max_grad_norm, |
866 | ) | 866 | ) |
867 | 867 | ||
868 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | ||
869 | |||
868 | create_datamodule = partial( | 870 | create_datamodule = partial( |
869 | VlpnDataModule, | 871 | VlpnDataModule, |
870 | data_file=args.train_data_file, | 872 | data_file=args.train_data_file, |
@@ -882,8 +884,8 @@ def main(): | |||
882 | valid_set_size=args.valid_set_size, | 884 | valid_set_size=args.valid_set_size, |
883 | train_set_pad=args.train_set_pad, | 885 | train_set_pad=args.train_set_pad, |
884 | valid_set_pad=args.valid_set_pad, | 886 | valid_set_pad=args.valid_set_pad, |
885 | seed=args.seed, | ||
886 | dtype=weight_dtype, | 887 | dtype=weight_dtype, |
888 | generator=data_generator, | ||
887 | ) | 889 | ) |
888 | 890 | ||
889 | create_lr_scheduler = partial( | 891 | create_lr_scheduler = partial( |