author    Volpeon <git@volpeon.ink>  2022-12-31 23:35:11 +0100
committer Volpeon <git@volpeon.ink>  2022-12-31 23:35:11 +0100
commit    b2db2b6a7c147cdc2901ece92f3918e5b3c47114 (patch)
tree      40cb2a3acd86b4f9dab674a3316058ac7ba2504c /models/clip
parent    Update (diff)
Fix
Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/embeddings.py  12
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index cab1515..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,6 +37,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
+        self.initializer_factor = config.initializer_factor
 
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
@@ -44,12 +45,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02)
+        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor)
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
+        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -63,14 +64,15 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
-            initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype)
+            initializer = self.get_embed(initializer)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
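
In short, the commit caches config.initializer_factor on the instance once and moves the dtype cast in add_embed behind the None check, so the cast can no longer run on a None result from get_embed. Below is a minimal standalone sketch of the resulting assignment pattern; the embedding size, the initializer_factor value, and the float64 initializer are illustrative assumptions, not values from the repo.

import torch
import torch.nn as nn

# Standalone sketch, not the repo's full ManagedCLIPTextEmbeddings class.
# The 0.02 std scaling follows the diff; sizes and dtypes are illustrative.
embed_dim = 768
initializer_factor = 1.0  # cached once, as self.initializer_factor above

temp_token_embedding = nn.Embedding(10, embed_dim)
temp_token_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)

token_ids = torch.tensor([3, 4], dtype=torch.long)
# An initializer in a different dtype, e.g. embeddings looked up elsewhere.
initializer = torch.randn(2, embed_dim, dtype=torch.float64)

# The dtype cast now happens at assignment time, inside the None check,
# mirroring the change to add_embed in this commit.
if initializer is not None:
    temp_token_embedding.weight.data[token_ids] = initializer.to(
        dtype=temp_token_embedding.weight.dtype)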