summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py
index 4bbc64e..0d8ee23 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -865,6 +865,8 @@ def main():
865 max_grad_norm=args.max_grad_norm, 865 max_grad_norm=args.max_grad_norm,
866 ) 866 )
867 867
868 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
869
868 create_datamodule = partial( 870 create_datamodule = partial(
869 VlpnDataModule, 871 VlpnDataModule,
870 data_file=args.train_data_file, 872 data_file=args.train_data_file,
@@ -882,8 +884,8 @@ def main():
882 valid_set_size=args.valid_set_size, 884 valid_set_size=args.valid_set_size,
883 train_set_pad=args.train_set_pad, 885 train_set_pad=args.train_set_pad,
884 valid_set_pad=args.valid_set_pad, 886 valid_set_pad=args.valid_set_pad,
885 seed=args.seed,
886 dtype=weight_dtype, 887 dtype=weight_dtype,
888 generator=data_generator,
887 ) 889 )
888 890
889 create_lr_scheduler = partial( 891 create_lr_scheduler = partial(