diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 6c57f4b..7f5fb49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -12,10 +12,12 @@ import torch.utils.checkpoint | |||
12 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | ||
16 | from timm.models import create_model | 15 | from timm.models import create_model |
17 | import transformers | 16 | import transformers |
18 | 17 | ||
18 | import numpy as np | ||
19 | from slugify import slugify | ||
20 | |||
19 | from util.files import load_config, load_embeddings_from_dir | 21 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 22 | from data.csv import VlpnDataModule, keyword_filter |
21 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
@@ -852,6 +854,7 @@ def main(): | |||
852 | ) | 854 | ) |
853 | 855 | ||
854 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 856 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
857 | data_npgenerator = np.random.default_rng(args.seed) | ||
855 | 858 | ||
856 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 859 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): |
857 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 860 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
@@ -894,6 +897,7 @@ def main(): | |||
894 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 897 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
895 | dtype=weight_dtype, | 898 | dtype=weight_dtype, |
896 | generator=data_generator, | 899 | generator=data_generator, |
900 | npgenerator=data_npgenerator, | ||
897 | ) | 901 | ) |
898 | datamodule.setup() | 902 | datamodule.setup() |
899 | 903 | ||