diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 91bda5c..d5dde02 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -13,9 +13,11 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from peft import LoraConfig, LoraModel | 15 | from peft import LoraConfig, LoraModel |
16 | from slugify import slugify | ||
17 | import transformers | 16 | import transformers |
18 | 17 | ||
18 | import numpy as np | ||
19 | from slugify import slugify | ||
20 | |||
19 | from util.files import load_config, load_embeddings_from_dir | 21 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 22 | from data.csv import VlpnDataModule, keyword_filter |
21 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
@@ -873,6 +875,7 @@ def main(): | |||
873 | ) | 875 | ) |
874 | 876 | ||
875 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 877 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
878 | data_npgenerator = np.random.default_rng(args.seed) | ||
876 | 879 | ||
877 | create_datamodule = partial( | 880 | create_datamodule = partial( |
878 | VlpnDataModule, | 881 | VlpnDataModule, |
@@ -893,6 +896,7 @@ def main(): | |||
893 | valid_set_pad=args.valid_set_pad, | 896 | valid_set_pad=args.valid_set_pad, |
894 | dtype=weight_dtype, | 897 | dtype=weight_dtype, |
895 | generator=data_generator, | 898 | generator=data_generator, |
899 | npgenerator=data_npgenerator, | ||
896 | ) | 900 | ) |
897 | 901 | ||
898 | create_lr_scheduler = partial( | 902 | create_lr_scheduler = partial( |