summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 16:30:36 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 16:30:36 +0200
commitc96073646bbb638d7d78fdd7d9fdeed08d1454b5 (patch)
tree3e0846964fa127844d652e2dee081cd67e672e6a /models
parentUpdate (diff)
downloadtextual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.gz
textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.bz2
textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.zip
Experimental: TI via LoRA
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py53
1 files changed, 38 insertions, 15 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9abd1bb..88e0cc0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
31 return new_embedding 31 return new_embedding
32 32
33 33
34class OverlayLinear(nn.Module):
35 def __init__(self, in_features, out_features, rank=4):
36 super().__init__()
37
38 if rank > min(in_features, out_features):
39 raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
40
41 self.rank = rank
42 self.down = nn.Linear(in_features, rank, bias=False)
43 self.up = nn.Linear(rank, out_features, bias=False)
44 self.reset()
45
46 def reset(self):
47 nn.init.normal_(self.down.weight, std=1 / self.rank)
48 nn.init.zeros_(self.up.weight)
49
50 def forward(self, hidden_states):
51 orig_dtype = hidden_states.dtype
52 dtype = self.down.weight.dtype
53
54 down_hidden_states = self.down(hidden_states.to(dtype))
55 up_hidden_states = self.up(down_hidden_states)
56
57 return up_hidden_states.to(orig_dtype)
58
59
34class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 60class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
35 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): 61 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
36 super().__init__(config) 62 super().__init__(config)
37 63
38 self.token_embedding = embeddings.token_embedding 64 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 65 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 66 self.initializer_factor = config.initializer_factor
41 67
42 self.temp_token_embedding = nn.Embedding( 68 self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
43 self.token_embedding.num_embeddings,
44 self.token_embedding.embedding_dim,
45 device=self.token_embedding.weight.device,
46 dtype=self.token_embedding.weight.dtype
47 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
49 self.temp_token_ids = torch.tensor([], dtype=torch.long) 69 self.temp_token_ids = torch.tensor([], dtype=torch.long)
50 70
71 def reset_overlay(self):
72 self.overlay.reset()
73
51 def resize(self, size: int): 74 def resize(self, size: int):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 75 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 76
55 def add_embed( 77 def add_embed(
@@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 initializer = self.get_embed(initializer) 96 initializer = self.get_embed(initializer)
75 97
76 initializer = initializer.to( 98 initializer = initializer.to(
77 device=self.temp_token_embedding.weight.device, 99 device=self.token_embedding.weight.device,
78 dtype=self.temp_token_embedding.weight.dtype, 100 dtype=self.token_embedding.weight.dtype,
79 ) 101 )
80 102
81 if initializer_noise != 0: 103 if initializer_noise != 0:
@@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
84 token_ids = torch.tensor(token_ids, dtype=torch.long) 106 token_ids = torch.tensor(token_ids, dtype=torch.long)
85 107
86 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 108 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
87 self.temp_token_embedding.weight.data[token_ids] = initializer
88 self.token_embedding.weight.data[token_ids] = initializer 109 self.token_embedding.weight.data[token_ids] = initializer
89 110
90 def load_embed(self, input_ids: list[int], filename: Path): 111 def load_embed(self, input_ids: list[int], filename: Path):
@@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
95 save_file({"embed": self.get_embed(input_ids)}, filename) 116 save_file({"embed": self.get_embed(input_ids)}, filename)
96 117
97 def persist(self): 118 def persist(self):
98 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 119 self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
120 self.token_embedding.weight.data[self.temp_token_ids]
121 )
122 self.overlay.reset()
99 self.temp_token_ids = torch.tensor([], dtype=torch.long) 123 self.temp_token_ids = torch.tensor([], dtype=torch.long)
100 124
101 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 125 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
103 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 127 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
104 128
105 embeds = self.token_embedding(input_ids) 129 embeds = self.token_embedding(input_ids)
106
107 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 130 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
108 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 131 embeds[mask] += self.overlay(embeds[mask])
109 132
110 return embeds 133 return embeds
111 134