| author | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 |
| commit | 86e908656bcd7585ec45cd930176800f759f146a (patch) | |
| tree | 1169e9b1728e4c6fc8b70e46a37080ae0794ada8 /models | |
| parent | Experimental: TI via LoRA (diff) | |
Combined TI with embedding and LoRA
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 19 |
1 file changed, 15 insertions, 4 deletions
```diff
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 88e0cc0..c9c788c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.initializer_factor = config.initializer_factor
 
         self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def reset_overlay(self):
         self.overlay.reset()
 
     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
-            self.token_embedding.weight.data[self.temp_token_ids]
-        )
+        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
         self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
+
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] += self.overlay(embeds[mask])
+
+        temp_embeds = self.temp_token_embedding(input_ids[mask])
+        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
 
         return embeds
```

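The key invariant in this commit is that `persist()` writes exactly what `get_embed()` was returning for the temporary ids, namely `temp + overlay(temp)`. A self-contained sketch of that arithmetic with plain tensors (toy sizes and a stand-in overlay for illustration; the real class uses `nn.Embedding` tables and `OverlayLinear`):

```python
import torch

# Toy stand-ins for the class's state.
vocab, dim = 10, 8
token_embedding = torch.randn(vocab, dim)
temp_token_embedding = token_embedding.clone()   # shadow copy, as in __init__
temp_token_ids = torch.tensor([7, 9])            # ids registered via add_embed


def overlay(x):
    # Stand-in for OverlayLinear's low-rank residual.
    return 0.01 * x


# Pretend training has moved the shadow rows away from the originals.
temp_token_embedding[temp_token_ids] += torch.randn(2, dim)

# get_embed(): plain lookup, then replace temp ids with temp + overlay(temp).
input_ids = torch.tensor([1, 7, 3, 9])
embeds = token_embedding[input_ids]
mask = torch.isin(input_ids, temp_token_ids)
temp = temp_token_embedding[input_ids[mask]]
embeds[mask] = temp + overlay(temp)

# persist(): bake the same quantity into the main table.
baked = temp_token_embedding[temp_token_ids]
token_embedding[temp_token_ids] = baked + overlay(baked)

# A plain lookup after persisting matches what get_embed() produced.
assert torch.allclose(token_embedding[input_ids][mask], embeds[mask])
```

This also shows why the diff sources the temporary rows from the shadow table rather than from `token_embedding` as before: existing vocabulary rows stay untouched during training, while the new tokens train through `temp_token_embedding` plus the overlay and are only folded into the main table on `persist()`.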