diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
| commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
| tree | 6ad882f152e63801d31230466e4d6468e7ada697 /common.py | |
| parent | Code simplifications, avoid autocast (diff) | |
| download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip | |
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'common.py')
| -rw-r--r-- | common.py | 14 |
1 files changed, 14 insertions, 0 deletions
| @@ -1,9 +1,23 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | ||
| 3 | |||
| 2 | import torch | 4 | import torch |
| 3 | 5 | ||
| 4 | from transformers import CLIPTextModel, CLIPTokenizer | 6 | from transformers import CLIPTextModel, CLIPTokenizer |
| 5 | 7 | ||
| 6 | 8 | ||
| 9 | def load_config(filename): | ||
| 10 | with open(filename, 'rt') as f: | ||
| 11 | config = json.load(f) | ||
| 12 | |||
| 13 | args = config["args"] | ||
| 14 | |||
| 15 | if "base" in config: | ||
| 16 | args = load_config(Path(filename).parent.joinpath(config["base"])) | args | ||
| 17 | |||
| 18 | return args | ||
| 19 | |||
| 20 | |||
| 7 | def load_text_embedding(embeddings, token_id, file): | 21 | def load_text_embedding(embeddings, token_id, file): |
| 8 | data = torch.load(file, map_location="cpu") | 22 | data = torch.load(file, map_location="cpu") |
| 9 | 23 | ||
