# common.py -- unified loading of TI embeddings.
#
# Reconstructed from commit 03303d3bddba5a27a123babdf90863e27501e6f8
# Author: Volpeon
# Date:   Tue, 13 Dec 2022 23:09:25 +0100
# Subject: Unified loading of TI embeddings

from pathlib import Path
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Only referenced in (string) annotations below, so the package is not
    # required at runtime just to import this module.
    from transformers import CLIPTextModel, CLIPTokenizer


def load_text_embedding(embeddings, token_id, file) -> None:
    """Load a single embedding from *file* into ``embeddings[token_id]``.

    The file is read with ``torch.load`` and must contain a mapping with
    exactly one entry; the value of that entry is the embedding. A 1-D value
    is given a leading axis of length 1 before it is assigned.

    Args:
        embeddings: Indexable tensor that is modified in place; the entry at
            ``token_id`` is overwritten.
        token_id: Index into ``embeddings`` that receives the loaded value.
        file: Anything ``torch.load`` accepts (e.g. a ``Path``).

    Raises:
        ValueError: If the loaded mapping does not hold exactly one entry.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at embedding files you trust.
    data = torch.load(file, map_location="cpu")

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently load only the first term.
    if len(data) != 1:
        raise ValueError(
            f"embedding data in {file} must contain exactly one term, "
            f"found {len(data)}"
        )

    emb = next(iter(data.values()))
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)

    embeddings[token_id] = emb


def load_text_embeddings(
    tokenizer: "CLIPTokenizer",
    text_encoder: "CLIPTextModel",
    embeddings_dir: Path,
) -> int:
    """Register every embedding file in *embeddings_dir* as a token.

    Each regular file in the directory becomes one token named after the
    file's stem. The tokens are added to ``tokenizer``, the encoder's token
    embedding table is resized to ``len(tokenizer)``, and each file's
    embedding is written into the row of its token id.

    Args:
        tokenizer: Tokenizer that receives the new tokens.
        text_encoder: Model whose input embedding weights are resized and
            overwritten in place.
        embeddings_dir: Directory to scan (non-recursively).

    Returns:
        The value returned by ``tokenizer.add_tokens`` for the discovered
        tokens, or ``0`` when the directory is missing, is not a directory,
        or contains no files.

    Raises:
        ValueError: If any file does not hold exactly one embedding term.
    """
    # is_dir() is already False for a path that does not exist.
    if not embeddings_dir.is_dir():
        return 0

    # iterdir() yields entries in arbitrary order; sort so that the
    # token -> id assignment is the same on every run and machine.
    files = sorted(path for path in embeddings_dir.iterdir() if path.is_file())
    if not files:
        return 0

    tokens = [path.stem for path in files]
    added = tokenizer.add_tokens(tokens)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Grow the embedding table before writing so the new ids are valid rows.
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data

    for token_id, path in zip(token_ids, files):
        load_text_embedding(token_embeds, token_id, path)

    return added