summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 6c57f4b..7f5fb49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,10 +12,12 @@ import torch.utils.checkpoint
12from accelerate import Accelerator 12from accelerate import Accelerator
13from accelerate.logging import get_logger 13from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from slugify import slugify
16from timm.models import create_model 15from timm.models import create_model
17import transformers 16import transformers
18 17
18import numpy as np
19from slugify import slugify
20
19from util.files import load_config, load_embeddings_from_dir 21from util.files import load_config, load_embeddings_from_dir
20from data.csv import VlpnDataModule, keyword_filter 22from data.csv import VlpnDataModule, keyword_filter
21from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
@@ -852,6 +854,7 @@ def main():
852 ) 854 )
853 855
854 data_generator = torch.Generator(device="cpu").manual_seed(args.seed) 856 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
857 data_npgenerator = np.random.default_rng(args.seed)
855 858
856 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): 859 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
857 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 860 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -894,6 +897,7 @@ def main():
894 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 897 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
895 dtype=weight_dtype, 898 dtype=weight_dtype,
896 generator=data_generator, 899 generator=data_generator,
900 npgenerator=data_npgenerator,
897 ) 901 )
898 datamodule.setup() 902 datamodule.setup()
899 903