summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..9c3a56b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
56 if isinstance(token_ids, int): 56 if isinstance(token_ids, int):
57 token_ids = [token_ids] 57 token_ids = [token_ids]
58 58
59 if initializer is not None: 59 if initializer is None:
60 if isinstance(initializer, int): 60 initializer = token_ids
61 initializer = [initializer]
62 61
63 if isinstance(initializer, list): 62 if isinstance(initializer, int):
64 initializer = (initializer * len(token_ids))[:len(token_ids)] 63 initializer = [initializer]
65 64
66 with torch.no_grad(): 65 if isinstance(initializer, list):
67 initializer = self.get_embed(initializer) 66 initializer = (initializer * len(token_ids))[:len(token_ids)]
67
68 with torch.no_grad():
69 initializer = self.get_embed(initializer)
68 70
69 token_ids = torch.tensor(token_ids, dtype=torch.long) 71 token_ids = torch.tensor(token_ids, dtype=torch.long)
70 72
71 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
72 74 self.temp_token_embedding.weight.data[token_ids] = initializer.to(
73 if initializer is not None: 75 dtype=self.temp_token_embedding.weight.dtype)
74 self.temp_token_embedding.weight.data[token_ids] = initializer.to(
75 dtype=self.temp_token_embedding.weight.dtype)
76 76
77 def load_embed(self, input_ids: list[int], filename: Path): 77 def load_embed(self, input_ids: list[int], filename: Path):
78 with safe_open(filename, framework="pt", device="cpu") as file: 78 with safe_open(filename, framework="pt", device="cpu") as file: