summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--   models/clip/embeddings.py | 5 +++++
1 files changed, 5 insertions, 0 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1280ebd..fb639f1 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        init_ratio = 1.0
+
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
+            init_ratio = len(initializer) / len(token_ids)
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 dtype=self.temp_token_embedding.weight.dtype,
             )
 
+        return init_ratio
+
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
             self.add_embed(input_ids, file.get_tensor("embed"))