diff options
author | Volpeon <git@volpeon.ink> | 2023-03-27 07:15:46 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-27 07:15:46 +0200 |
commit | 0e4c36889aa6b7ec13320a03728118c7c1a8e716 (patch) | |
tree | 461e63354dac6ab5b68d0f57e1569798df5bf202 /models | |
parent | Fix TI embeddings init (diff) | |
download | textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.gz textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.bz2 textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.zip |
Sparse TI embeddings without sparse tensors
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 870ee49..95904cf 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
42 | self.init_temp_embeddings() | 42 | self.init_temp_embeddings() |
43 | 43 | ||
44 | def init_temp_embeddings(self): | 44 | def init_temp_embeddings(self): |
45 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.ParameterList() |
46 | 0, | ||
47 | self.token_embedding.embedding_dim, | ||
48 | device=self.token_embedding.weight.device, | ||
49 | dtype=self.token_embedding.weight.dtype | ||
50 | ) | ||
51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 46 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
52 | 47 | ||
53 | def resize(self, size: int): | 48 | def resize(self, size: int): |
54 | self.temp_token_embedding = resize_embedding( | 49 | for _ in range(len(self.temp_token_embedding), size): |
55 | self.temp_token_embedding, | 50 | self.temp_token_embedding.append(torch.zeros( |
56 | size - self.num_permanent_embeddings, | 51 | self.token_embedding.embedding_dim, |
57 | self.initializer_factor | 52 | device=self.token_embedding.weight.device, |
58 | ) | 53 | dtype=self.token_embedding.weight.dtype, |
54 | )) | ||
59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 55 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
60 | 56 | ||
61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 57 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
@@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
74 | with torch.no_grad(): | 70 | with torch.no_grad(): |
75 | initializer = self.get_embed(initializer) | 71 | initializer = self.get_embed(initializer) |
76 | 72 | ||
73 | initializer = initializer.to( | ||
74 | device=self.token_embedding.weight.device, | ||
75 | dtype=self.token_embedding.weight.dtype, | ||
76 | ) | ||
77 | |||
77 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 78 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
78 | 79 | ||
79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 80 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
80 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | 81 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) |
81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | 82 | |
82 | device=self.temp_token_embedding.weight.device, | 83 | for i, id in enumerate(mask): |
83 | dtype=self.temp_token_embedding.weight.dtype, | 84 | self.temp_token_embedding[id] = initializer[i] |
84 | ) | ||
85 | 85 | ||
86 | def load_embed(self, input_ids: list[int], filename: Path): | 86 | def load_embed(self, input_ids: list[int], filename: Path): |
87 | with safe_open(filename, framework="pt", device="cpu") as file: | 87 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
91 | save_file({"embed": self.get_embed(input_ids)}, filename) | 91 | save_file({"embed": self.get_embed(input_ids)}, filename) |
92 | 92 | ||
93 | def persist(self): | 93 | def persist(self): |
94 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 94 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): |
95 | self.token_embedding.weight.data[id] = emb | ||
95 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | 96 | self.num_permanent_embeddings = self.token_embedding.num_embeddings |
96 | self.init_temp_embeddings() | 97 | self.init_temp_embeddings() |
97 | 98 | ||
@@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
110 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | 111 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) |
111 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | 112 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() |
112 | 113 | ||
113 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | 114 | if len(temp_token_ids): |
115 | embeds_override = torch.stack([ | ||
116 | self.temp_token_embedding[id] | ||
117 | for id in temp_token_ids | ||
118 | ]) | ||
119 | embeds[embeds_mask] = embeds_override | ||
114 | 120 | ||
115 | return embeds | 121 | return embeds |
116 | 122 | ||