diff options
Diffstat (limited to 'common.py')
-rw-r--r-- | common.py | 14 |
1 files changed, 14 insertions, 0 deletions
@@ -1,9 +1,23 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | ||
3 | |||
2 | import torch | 4 | import torch |
3 | 5 | ||
4 | from transformers import CLIPTextModel, CLIPTokenizer | 6 | from transformers import CLIPTextModel, CLIPTokenizer |
5 | 7 | ||
6 | 8 | ||
9 | def load_config(filename): | ||
10 | with open(filename, 'rt') as f: | ||
11 | config = json.load(f) | ||
12 | |||
13 | args = config["args"] | ||
14 | |||
15 | if "base" in config: | ||
16 | args = load_config(Path(filename).parent.joinpath(config["base"])) | args | ||
17 | |||
18 | return args | ||
19 | |||
20 | |||
7 | def load_text_embedding(embeddings, token_id, file): | 21 | def load_text_embedding(embeddings, token_id, file): |
8 | data = torch.load(file, map_location="cpu") | 22 | data = torch.load(file, map_location="cpu") |
9 | 23 | ||