from pathlib import Path
import json

from models.clip.embeddings import ManagedCLIPTextEmbeddings
from models.clip.tokenizer import MultiCLIPTokenizer

from safetensors import safe_open


def load_config(filename):
    """Load a JSON config file and return its "args" mapping.

    If the config contains a "base" key, the referenced file (resolved
    relative to *filename*'s directory) is loaded recursively first and
    the current file's "args" override the base's (dict-union: right
    operand wins on key conflicts).

    Args:
        filename: Path (str or Path) to a JSON config file containing at
            least an "args" object.

    Returns:
        dict: the merged "args" mapping.

    Raises:
        KeyError: if the file has no "args" key.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit UTF-8: JSON files are UTF-8 by spec; don't depend on the
    # platform's locale default encoding.
    with open(filename, 'rt', encoding='utf-8') as f:
        config = json.load(f)

    args = config["args"]

    if "base" in config:
        # Base config first, then overlay this file's args on top.
        args = load_config(Path(filename).parent.joinpath(config["base"])) | args

    return args


def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
    """Load all safetensors embedding files from a directory into the tokenizer/embeddings.

    Each regular file in *embeddings_dir* is treated as one token: the file
    stem is the token string and the tensor stored under the "embed" key
    holds its vectors (first dimension = number of sub-token vectors).

    Args:
        tokenizer: tokenizer that receives the new multi-tokens.
        embeddings: managed embedding table that is resized and filled.
        embeddings_dir: directory containing one safetensors file per token.

    Returns:
        tuple[list[str], list[list[int]]]: the token strings and, for each
        token, the list of new token ids registered for it. Both lists are
        empty when the directory does not exist or is not a directory.
    """
    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
        # BUG FIX: previously returned a bare [] here, while the normal path
        # returns a (tokens, new_ids) tuple — callers that unpack two values
        # would crash. Keep the return shape consistent.
        return [], []

    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
    tokens = [filename.stem for filename in filenames]

    new_ids: list[list[int]] = []
    new_embeds = []

    for filename in filenames:
        with safe_open(filename, framework="pt", device="cpu") as file:
            embed = file.get_tensor("embed")

        # embed.shape[0] = number of vectors this token expands into.
        added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
        new_ids.append(added)
        new_embeds.append(embed)

    # Resize once, after all tokens are registered, so every new id fits.
    embeddings.resize(len(tokenizer))

    for (new_id, embeds) in zip(new_ids, new_embeds):
        embeddings.add_embed(new_id, embeds)

    return tokens, new_ids