from pathlib import Path

import torch

from models.clip.embeddings import ManagedCLIPTextEmbeddings
from models.clip.tokenizer import MultiCLIPTokenizer


def load_embeddings(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    tokens: list[str],
    token_embeddings: torch.FloatTensor,
):
    """Register new tokens with ``tokenizer`` and copy their embeddings in.

    Entry ``i`` of ``token_embeddings`` supplies the vectors for ``tokens[i]``.
    The number of vectors for each token is read from the first dimension
    (``shape[0]``) of its entry.

    Args:
        tokenizer: Tokenizer that receives the new tokens via
            ``add_multi_tokens``.
        embeddings: Embedding table. It is resized to ``len(tokenizer)``
            after the tokens are added, then filled via ``add_embed``.
        tokens: New token strings to add.
        token_embeddings: One entry per token, iterated along the first axis.
            Presumably each entry is shaped (num_vectors, dim) — TODO confirm
            against callers.

    Returns:
        A ``(tokens, token_ids)`` tuple. ``token_ids`` is exactly what
        ``tokenizer.add_multi_tokens`` returned.

    Raises:
        ValueError: If ``tokens`` and ``token_embeddings`` differ in length.
            Without this check, ``zip`` below would silently drop the
            unmatched entries.
    """
    if len(tokens) != len(token_embeddings):
        raise ValueError(
            f"got {len(tokens)} tokens but {len(token_embeddings)} embeddings; "
            "each token needs exactly one embedding entry"
        )

    # One vector count per token, taken from the leading dimension of its entry.
    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)

    # Resize the table to the tokenizer's size after the additions above.
    embeddings.resize(len(tokenizer))

    for new_id, embeds in zip(token_ids, token_embeddings):
        embeddings.add_embed(new_id, embeds)

    return tokens, token_ids