diff options
Diffstat (limited to 'common.py')
| -rw-r--r-- | common.py | 14 |
1 files changed, 14 insertions, 0 deletions
| @@ -1,9 +1,23 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | ||
| 3 | |||
| 2 | import torch | 4 | import torch |
| 3 | 5 | ||
| 4 | from transformers import CLIPTextModel, CLIPTokenizer | 6 | from transformers import CLIPTextModel, CLIPTokenizer |
| 5 | 7 | ||
| 6 | 8 | ||
| 9 | def load_config(filename): | ||
| 10 | with open(filename, 'rt') as f: | ||
| 11 | config = json.load(f) | ||
| 12 | |||
| 13 | args = config["args"] | ||
| 14 | |||
| 15 | if "base" in config: | ||
| 16 | args = load_config(Path(filename).parent.joinpath(config["base"])) | args | ||
| 17 | |||
| 18 | return args | ||
| 19 | |||
| 20 | |||
| 7 | def load_text_embedding(embeddings, token_id, file): | 21 | def load_text_embedding(embeddings, token_id, file): |
| 8 | data = torch.load(file, map_location="cpu") | 22 | data = torch.load(file, map_location="cpu") |
| 9 | 23 | ||
