path: root/models/clip/embeddings.py
author     Volpeon <git@volpeon.ink>  2023-04-01 22:13:55 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-01 22:13:55 +0200
commit     208e48134e324e934ad964bdc61880cc923f4c0d (patch)
tree       c215f6c201c04b0b2d18ba0df230fb4c5e622985 /models/clip/embeddings.py
parent     Fix (diff)
Revert
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--  models/clip/embeddings.py  42
1 file changed, 4 insertions(+), 38 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
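For context, the OverlayLinear class removed above is a LoRA-style low-rank residual: a down-projection to `rank` followed by an up-projection back to the embedding width, with the up weight zero-initialized so the overlay contributes nothing until it has been trained. A minimal sketch of that behaviour, assuming torch is available; the sizes below are illustrative and not taken from this repository:

import torch
import torch.nn as nn

rank, dim = 4, 768                        # illustrative sizes
down = nn.Linear(dim, rank, bias=False)   # mirrors OverlayLinear.down
up = nn.Linear(rank, dim, bias=False)     # mirrors OverlayLinear.up
nn.init.normal_(down.weight, std=1 / rank)
nn.init.zeros_(up.weight)                 # reset(): the overlay starts as a no-op

x = torch.randn(2, dim)                   # e.g. two token embeddings
delta = up(down(x))                       # the low-rank correction added on top
print(torch.count_nonzero(delta))         # 0 right after reset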
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
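The persist() change above is the behavioural core of the revert: previously the rows written back into token_embedding were the temporary embeddings plus the overlay's low-rank correction; now the trained temporary rows are copied over verbatim. A toy sketch of the post-revert data flow, where the tensor names are illustrative stand-ins rather than identifiers from the file:

import torch

token_weight = torch.zeros(10, 4)   # stands in for token_embedding.weight.data
temp_weight = torch.randn(10, 4)    # stands in for temp_token_embedding.weight.data
temp_ids = torch.tensor([7, 8], dtype=torch.long)

# Post-revert persist(): copy the trained temporary rows directly,
# then clear the bookkeeping of which ids were temporary.
token_weight[temp_ids] = temp_weight[temp_ids]
temp_ids = torch.tensor([], dtype=torch.long)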
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
 
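Likewise, get_embed() now substitutes the temporary embeddings directly instead of adding an overlay term on top of them. A self-contained sketch of the reverted lookup path, using toy sizes and ids chosen only for illustration:

import torch
import torch.nn as nn

token_embedding = nn.Embedding(10, 4)
temp_token_embedding = nn.Embedding(10, 4)
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)

input_ids = torch.tensor([1, 7, 3, 8], dtype=torch.long)
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)           # which positions are temporary tokens
embeds[mask] = temp_token_embedding(input_ids[mask])   # direct lookup, no overlay correction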