diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 16:30:36 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 16:30:36 +0200 |
commit | c96073646bbb638d7d78fdd7d9fdeed08d1454b5 (patch) | |
tree | 3e0846964fa127844d652e2dee081cd67e672e6a /models/clip | |
parent | Update (diff) | |
download | textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.gz textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.bz2 textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.zip |
Experimental: TI via LoRA
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 53 |
1 files changed, 38 insertions, 15 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9abd1bb..88e0cc0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
31 | return new_embedding | 31 | return new_embedding |
32 | 32 | ||
33 | 33 | ||
34 | class OverlayLinear(nn.Module): | ||
35 | def __init__(self, in_features, out_features, rank=4): | ||
36 | super().__init__() | ||
37 | |||
38 | if rank > min(in_features, out_features): | ||
39 | raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}") | ||
40 | |||
41 | self.rank = rank | ||
42 | self.down = nn.Linear(in_features, rank, bias=False) | ||
43 | self.up = nn.Linear(rank, out_features, bias=False) | ||
44 | self.reset() | ||
45 | |||
46 | def reset(self): | ||
47 | nn.init.normal_(self.down.weight, std=1 / self.rank) | ||
48 | nn.init.zeros_(self.up.weight) | ||
49 | |||
50 | def forward(self, hidden_states): | ||
51 | orig_dtype = hidden_states.dtype | ||
52 | dtype = self.down.weight.dtype | ||
53 | |||
54 | down_hidden_states = self.down(hidden_states.to(dtype)) | ||
55 | up_hidden_states = self.up(down_hidden_states) | ||
56 | |||
57 | return up_hidden_states.to(orig_dtype) | ||
58 | |||
59 | |||
34 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 60 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
35 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | 61 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128): |
36 | super().__init__(config) | 62 | super().__init__(config) |
37 | 63 | ||
38 | self.token_embedding = embeddings.token_embedding | 64 | self.token_embedding = embeddings.token_embedding |
39 | self.position_embedding = embeddings.position_embedding | 65 | self.position_embedding = embeddings.position_embedding |
40 | self.initializer_factor = config.initializer_factor | 66 | self.initializer_factor = config.initializer_factor |
41 | 67 | ||
42 | self.temp_token_embedding = nn.Embedding( | 68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) |
43 | self.token_embedding.num_embeddings, | ||
44 | self.token_embedding.embedding_dim, | ||
45 | device=self.token_embedding.weight.device, | ||
46 | dtype=self.token_embedding.weight.dtype | ||
47 | ) | ||
48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 69 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
50 | 70 | ||
71 | def reset_overlay(self): | ||
72 | self.overlay.reset() | ||
73 | |||
51 | def resize(self, size: int): | 74 | def resize(self, size: int): |
52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 75 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 76 | ||
55 | def add_embed( | 77 | def add_embed( |
@@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
74 | initializer = self.get_embed(initializer) | 96 | initializer = self.get_embed(initializer) |
75 | 97 | ||
76 | initializer = initializer.to( | 98 | initializer = initializer.to( |
77 | device=self.temp_token_embedding.weight.device, | 99 | device=self.token_embedding.weight.device, |
78 | dtype=self.temp_token_embedding.weight.dtype, | 100 | dtype=self.token_embedding.weight.dtype, |
79 | ) | 101 | ) |
80 | 102 | ||
81 | if initializer_noise != 0: | 103 | if initializer_noise != 0: |
@@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
84 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 106 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
85 | 107 | ||
86 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 108 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
87 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
88 | self.token_embedding.weight.data[token_ids] = initializer | 109 | self.token_embedding.weight.data[token_ids] = initializer |
89 | 110 | ||
90 | def load_embed(self, input_ids: list[int], filename: Path): | 111 | def load_embed(self, input_ids: list[int], filename: Path): |
@@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
95 | save_file({"embed": self.get_embed(input_ids)}, filename) | 116 | save_file({"embed": self.get_embed(input_ids)}, filename) |
96 | 117 | ||
97 | def persist(self): | 118 | def persist(self): |
98 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 119 | self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( |
120 | self.token_embedding.weight.data[self.temp_token_ids] | ||
121 | ) | ||
122 | self.overlay.reset() | ||
99 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 123 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
100 | 124 | ||
101 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 125 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
@@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
103 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 127 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
104 | 128 | ||
105 | embeds = self.token_embedding(input_ids) | 129 | embeds = self.token_embedding(input_ids) |
106 | |||
107 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 130 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
108 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 131 | embeds[mask] += self.overlay(embeds[mask]) |
109 | 132 | ||
110 | return embeds | 133 | return embeds |
111 | 134 | ||