diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 4bbc64e..0d8ee23 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -865,6 +865,8 @@ def main(): | |||
| 865 | max_grad_norm=args.max_grad_norm, | 865 | max_grad_norm=args.max_grad_norm, |
| 866 | ) | 866 | ) |
| 867 | 867 | ||
| 868 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | ||
| 869 | |||
| 868 | create_datamodule = partial( | 870 | create_datamodule = partial( |
| 869 | VlpnDataModule, | 871 | VlpnDataModule, |
| 870 | data_file=args.train_data_file, | 872 | data_file=args.train_data_file, |
| @@ -882,8 +884,8 @@ def main(): | |||
| 882 | valid_set_size=args.valid_set_size, | 884 | valid_set_size=args.valid_set_size, |
| 883 | train_set_pad=args.train_set_pad, | 885 | train_set_pad=args.train_set_pad, |
| 884 | valid_set_pad=args.valid_set_pad, | 886 | valid_set_pad=args.valid_set_pad, |
| 885 | seed=args.seed, | ||
| 886 | dtype=weight_dtype, | 887 | dtype=weight_dtype, |
| 888 | generator=data_generator, | ||
| 887 | ) | 889 | ) |
| 888 | 890 | ||
| 889 | create_lr_scheduler = partial( | 891 | create_lr_scheduler = partial( |
