author     Volpeon <git@volpeon.ink>  2023-04-03 22:25:20 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-03 22:25:20 +0200
commit     2e654c017780d37f3304436e2feb84b619f1c023 (patch)
tree       8a248fe17c3512110de9fcfed7f7bfd708b3b8da
parent     TI: Delta learning (diff)
Improved sparse embeddings
-rw-r--r--  models/clip/embeddings.py   52
-rw-r--r--  models/sparse.py            57
-rw-r--r--  train_ti.py                  2
-rw-r--r--  training/strategy/ti.py      8
4 files changed, 83 insertions, 36 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d8343a0..a356434 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,6 +11,8 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
+from models.sparse import PseudoSparseEmbedding
+
 
 def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
@@ -41,18 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.alpha = alpha
 
-        self.temp_token_embedding = nn.ParameterList()
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        self.token_override_embedding = PseudoSparseEmbedding(
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
+        )
+        self.alpha = alpha
 
     def resize(self, size: int):
-        for _ in range(len(self.temp_token_embedding), size):
-            self.temp_token_embedding.append(torch.zeros(
-                self.token_embedding.embedding_dim,
-                device=self.token_embedding.weight.device,
-                dtype=self.token_embedding.weight.dtype,
-            ))
+        self.token_override_embedding.resize(size)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -86,8 +86,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
-        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         self.token_embedding.weight.data[token_ids] = initializer
+        self.token_override_embedding.set(token_ids)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,33 +97,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
-            self.token_embedding.weight.data[id] += self.alpha * emb
-            nn.init.zeros_(emb)
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        embs, mask = self.token_override_embedding(input_ids)
+        if embs is not None:
+            input_ids = input_ids[mask]
+            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
-        embeds = self.token_embedding(input_ids)
-        mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        if len(temp_token_ids):
-            embeds_override = torch.stack([
-                self.temp_token_embedding[id]
-                for id in temp_token_ids
-            ])
-            embeds[mask] += self.alpha * embeds_override
+        embs = self.token_embedding(input_ids)
+        embs_override, mask = self.token_override_embedding(input_ids)
+        if embs_override is not None:
+            embs[mask] += self.alpha * embs_override
 
-        return embeds
+        return embs
 
     def forward(
         self,
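
The net effect of the get_embed() change above: the dense token_embedding lookup is left untouched, and only tokens registered in the override table are nudged by alpha times their sparse delta. Below is a minimal standalone sketch of that combination step, not part of the commit; it assumes models/sparse.py (added below) is importable and uses toy sizes instead of the real CLIP dimensions.

    import torch
    import torch.nn as nn

    from models.sparse import PseudoSparseEmbedding

    token_embedding = nn.Embedding(10, 4)        # stand-in for CLIP's token embedding
    override = PseudoSparseEmbedding(4)
    override.resize(10)
    override.set(torch.tensor([3]))              # token 3 gets a zero-initialized trainable delta
    alpha = 1.0

    input_ids = torch.tensor([1, 3, 7])
    embs = token_embedding(input_ids)
    embs_override, mask = override(input_ids)
    if embs_override is not None:
        embs[mask] += alpha * embs_override      # only the overridden token is adjusted
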
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..0b15454
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class PseudoSparseEmbedding(nn.Module):
+    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.dtype = dtype
+        self.params = nn.ParameterList()
+        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+
+    def forward(self, input_ids: Optional[torch.LongTensor] = None):
+        if input_ids is None:
+            input_ids = torch.arange(self.mapping.shape[0])
+
+        ids = self.mapping[input_ids.to(self.mapping.device)]
+        mask = ~(ids == -1)
+
+        if torch.all(~mask):
+            embs = None
+        else:
+            embs = torch.stack([self.params[id] for id in ids[mask]])
+
+        return embs, mask
+
+    def resize(self, new_num_embeddings: int):
+        old_num_embeddings = self.mapping.shape[0]
+        n = min(old_num_embeddings, new_num_embeddings)
+
+        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
+        new_mapping[:n] = self.mapping[:n]
+
+        self.mapping = new_mapping
+
+    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
+        if len(input_ids.shape) != 0:
+            if tensor is not None:
+                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
+            else:
+                return [self.set(id) for id in input_ids]
+
+        id = self.mapping[input_ids]
+
+        if id == -1:
+            id = len(self.params)
+            self.mapping[input_ids] = id
+            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
+
+        self.params[id] = tensor if tensor is not None else torch.zeros(
+            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+
+    def unset(self, input_ids: torch.LongTensor):
+        self.mapping[input_ids] = -1
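
For reference, a hedged usage sketch of PseudoSparseEmbedding, not part of the commit: resize() reserves mapping slots, set() lazily allocates a zero vector per token id, forward() returns the stacked overrides plus a boolean mask (or None when nothing in the lookup is overridden), and unset() drops the mapping again. Assumes a PyTorch version (1.12+) whose nn.ParameterList wraps plain tensors as Parameters.

    import torch

    from models.sparse import PseudoSparseEmbedding

    emb = PseudoSparseEmbedding(embedding_dim=4)
    emb.resize(10)                       # token ids 0..9, all unmapped (-1)
    emb.set(torch.tensor([3]))           # allocate a zero override for token id 3

    embs, mask = emb(torch.tensor([2, 3, 5]))
    print(mask)                          # tensor([False,  True, False])
    print(embs.shape)                    # torch.Size([1, 4])

    emb.unset(torch.tensor([3]))
    embs, mask = emb(torch.tensor([2, 3, 5]))
    print(embs)                          # None: no overridden ids left in this lookup
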
diff --git a/train_ti.py b/train_ti.py
index 0ad7574..a9a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -809,7 +809,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
         lr=args.learning_rate,
     )
 
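
What the train_ti.py change amounts to: the optimizer only ever sees the sparse override vectors, so the full token embedding table stays frozen until persist() folds the scaled deltas back in. A rough illustration, not from the commit, with assumed CLIP-like sizes (49408 tokens, 768 dimensions):

    import torch

    from models.sparse import PseudoSparseEmbedding

    override = PseudoSparseEmbedding(768)
    override.resize(49408)
    override.set(torch.tensor([49407]))   # e.g. a single newly added placeholder token

    optimizer = torch.optim.AdamW(override.params.parameters(), lr=1e-4)
    trainable = sum(p.numel() for p in override.params.parameters())
    print(trainable)                      # 768: one override vector, not the whole table
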
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 16baa34..95128da 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
@@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.token_override_embedding.params
 
     @contextmanager
     def on_train(epoch: int):
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
     def on_log():
         if ema_embeddings is not None: