| author | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
| commit | 208e48134e324e934ad964bdc61880cc923f4c0d (patch) | |
| tree | c215f6c201c04b0b2d18ba0df230fb4c5e622985 /models/clip | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.gz<br>textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.bz2<br>textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.zip | |
Revert
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 42 |
1 file changed, 4 insertions, 38 deletions
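For context, the `OverlayLinear` module deleted below is a LoRA-style low-rank residual: it projects an embedding down to `rank` dimensions and back up, and the result is added to the embedding. Because `up` is zero-initialized, a freshly reset overlay contributes exactly nothing. A minimal standalone sketch (the class body is copied from the deleted lines; the 768-dim test at the end is illustrative only):

```python
import torch
import torch.nn as nn

# Copied from the lines deleted in the diff below: a LoRA-style
# low-rank residual (down-project to `rank`, then up-project back).
class OverlayLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        if rank > min(in_features, out_features):
            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
        self.rank = rank
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.reset()

    def reset(self):
        nn.init.normal_(self.down.weight, std=1 / self.rank)
        nn.init.zeros_(self.up.weight)  # zero-init: the residual starts at exactly 0

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype
        return self.up(self.down(hidden_states.to(dtype))).to(orig_dtype)

# Because `up` starts at zero, a fresh overlay is a no-op,
# so dropping it does not change untrained embeddings:
overlay = OverlayLinear(768, 768, rank=4)
x = torch.randn(2, 768)
assert torch.equal(overlay(x), torch.zeros_like(x))
```

This is why the revert is behavior-preserving for untrained embeddings: only whatever the low-rank path had already learned is discarded.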
```diff
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
 
```
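After the revert, temporary token embeddings are used directly: `get_embed` overwrites the rows belonging to temporary token ids, and `persist` copies those rows back into the main embedding table. A toy walk-through with made-up sizes (it mirrors the post-revert logic, not the full class):

```python
import torch
import torch.nn as nn

# Illustrative sizes; the logic follows ManagedCLIPTextEmbeddings
# get_embed/persist as of this commit.
token_embedding = nn.Embedding(10, 4)
temp_token_embedding = nn.Embedding(10, 4)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
temp_token_ids = torch.tensor([7], dtype=torch.long)  # one "temporary" token

# ... training would update temp_token_embedding.weight[7] here ...

# get_embed: look up all tokens, then let the temp rows win.
input_ids = torch.tensor([1, 7, 3], dtype=torch.long)
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
embeds[mask] = temp_token_embedding(input_ids[mask])  # no overlay term anymore

# persist: copy the trained rows back and clear the temp id list.
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]
temp_token_ids = torch.tensor([], dtype=torch.long)
```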
