diff options
-rw-r--r-- | models/clip/embeddings.py | 52 | ||||
-rw-r--r-- | models/sparse.py | 57 | ||||
-rw-r--r-- | train_ti.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 8 |
4 files changed, 83 insertions, 36 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index d8343a0..a356434 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -11,6 +11,8 @@ from transformers import CLIPTextModel | |||
11 | from transformers.models.clip import CLIPTextConfig | 11 | from transformers.models.clip import CLIPTextConfig |
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
13 | 13 | ||
14 | from models.sparse import PseudoSparseEmbedding | ||
15 | |||
14 | 16 | ||
15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: | 17 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: |
16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape | 18 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape |
@@ -41,18 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
41 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
42 | self.position_embedding = embeddings.position_embedding | 44 | self.position_embedding = embeddings.position_embedding |
43 | self.initializer_factor = config.initializer_factor | 45 | self.initializer_factor = config.initializer_factor |
44 | self.alpha = alpha | ||
45 | 46 | ||
46 | self.temp_token_embedding = nn.ParameterList() | 47 | self.token_override_embedding = PseudoSparseEmbedding( |
47 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 48 | self.token_embedding.embedding_dim, |
49 | device=self.token_embedding.weight.device, | ||
50 | dtype=self.token_embedding.weight.dtype, | ||
51 | ) | ||
52 | self.alpha = alpha | ||
48 | 53 | ||
def resize(self, size: int):
    """Grow the vocabulary to `size` tokens.

    Resizes the sparse override table first, then swaps in a resized dense
    token embedding (initialised via `resize_embedding`, preserving old rows).
    """
    self.token_override_embedding.resize(size)
    self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
57 | 57 | ||
58 | def add_embed( | 58 | def add_embed( |
@@ -86,8 +86,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
86 | 86 | ||
87 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 87 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
88 | 88 | ||
89 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | ||
90 | self.token_embedding.weight.data[token_ids] = initializer | 89 | self.token_embedding.weight.data[token_ids] = initializer |
90 | self.token_override_embedding.set(token_ids) | ||
91 | 91 | ||
92 | def load_embed(self, input_ids: list[int], filename: Path): | 92 | def load_embed(self, input_ids: list[int], filename: Path): |
93 | with safe_open(filename, framework="pt", device="cpu") as file: | 93 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -97,33 +97,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
97 | save_file({"embed": self.get_embed(input_ids)}, filename) | 97 | save_file({"embed": self.get_embed(input_ids)}, filename) |
98 | 98 | ||
def persist(self):
    """Bake every trained override vector into the dense token embedding.

    For each token that has an override, adds `alpha * override` to the
    corresponding row of `token_embedding.weight` and then removes the
    override so subsequent lookups use the dense table alone.
    """
    all_ids = torch.arange(self.token_embedding.num_embeddings)
    overrides, mask = self.token_override_embedding(all_ids)
    if overrides is None:
        # No token currently has an override; nothing to merge.
        return
    overridden_ids = all_ids[mask]
    self.token_embedding.weight.data[overridden_ids] += self.alpha * overrides
    self.token_override_embedding.unset(overridden_ids)
104 | 106 | ||
def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
    """Embed `input_ids`, blending in alpha-scaled override vectors.

    Tokens with a trained override get `alpha * override` added on top of
    their dense embedding; all other tokens are returned unchanged.
    """
    if isinstance(input_ids, list):
        device = self.token_embedding.weight.device
        input_ids = torch.tensor(input_ids, device=device, dtype=torch.long)

    result = self.token_embedding(input_ids)
    override, override_mask = self.token_override_embedding(input_ids)
    if override is not None:
        # NOTE(review): assumes the mask lives on the same device as `result`
        # (true when both submodules share a device) — confirm for multi-device.
        result[override_mask] += self.alpha * override
    return result
127 | 117 | ||
128 | def forward( | 118 | def forward( |
129 | self, | 119 | self, |
diff --git a/models/sparse.py b/models/sparse.py new file mode 100644 index 0000000..0b15454 --- /dev/null +++ b/models/sparse.py | |||
@@ -0,0 +1,57 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | |||
class PseudoSparseEmbedding(nn.Module):
    """A sparse set of per-token override embedding vectors.

    `mapping` is a dense long tensor of length num_embeddings whose entry is
    -1 ("no override") or an index into `params`, an nn.ParameterList holding
    one `embedding_dim` vector per overridden token.
    """

    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
        # NOTE(review): plain tensor attribute — it will not follow
        # Module.to()/state_dict(); callers pass `device` explicitly.
        self.mapping = torch.zeros(0, device=device, dtype=torch.long)

    def forward(self, input_ids: Optional[torch.LongTensor] = None):
        """Return `(embs, mask)` for `input_ids` (default: all ids).

        `mask` flags which of `input_ids` have an override; `embs` stacks the
        corresponding vectors in order, or is None when nothing is overridden.
        """
        if input_ids is None:
            input_ids = torch.arange(self.mapping.shape[0], device=self.mapping.device)

        idxs = self.mapping[input_ids.to(self.mapping.device)]
        mask = idxs != -1

        if not mask.any():
            return None, mask

        embs = torch.stack([self.params[idx] for idx in idxs[mask]])
        return embs, mask

    def resize(self, new_num_embeddings: int):
        """Resize the id→slot mapping; new entries start with no override."""
        n = min(self.mapping.shape[0], new_num_embeddings)
        new_mapping = torch.full(
            (new_num_embeddings,), -1, device=self.mapping.device, dtype=torch.long)
        new_mapping[:n] = self.mapping[:n]
        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        """Install an override vector (zeros when `tensor` is None) per id.

        Accepts a 0-dim id or a batch of ids (with an optional matching batch
        of vectors); batches recurse element-wise.
        """
        if input_ids.dim() != 0:
            if tensor is not None:
                return [self.set(one_id, t) for one_id, t in zip(input_ids, tensor)]
            return [self.set(one_id) for one_id in input_ids]

        value = tensor if tensor is not None else torch.zeros(
            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
        # Wrap in nn.Parameter: a plain tensor stored in an nn.ParameterList is
        # not registered, so .parameters() would yield nothing for the optimizer
        # (train_ti.py trains exactly `...token_override_embedding.params.parameters()`).
        param = nn.Parameter(value.detach().clone().to(
            device=self.mapping.device, dtype=self.dtype))

        idx = int(self.mapping[input_ids])
        if idx == -1:
            idx = len(self.params)
            self.mapping[input_ids] = idx
            self.params.append(param)
        else:
            self.params[idx] = param

    def unset(self, input_ids: torch.LongTensor):
        """Drop the override for the given ids (slots stay allocated but unmapped)."""
        self.mapping[input_ids] = -1
diff --git a/train_ti.py b/train_ti.py index 0ad7574..a9a2333 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -809,7 +809,7 @@ def main(): | |||
809 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 809 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
810 | 810 | ||
811 | optimizer = create_optimizer( | 811 | optimizer = create_optimizer( |
812 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 812 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
813 | lr=args.learning_rate, | 813 | lr=args.learning_rate, |
814 | ) | 814 | ) |
815 | 815 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 16baa34..95128da 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks( | |||
69 | 69 | ||
70 | if use_ema: | 70 | if use_ema: |
71 | ema_embeddings = EMAModel( | 71 | ema_embeddings = EMAModel( |
72 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 72 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
73 | inv_gamma=ema_inv_gamma, | 73 | inv_gamma=ema_inv_gamma, |
74 | power=ema_power, | 74 | power=ema_power, |
75 | max_value=ema_max_decay, | 75 | max_value=ema_max_decay, |
@@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks( | |||
81 | def ema_context(): | 81 | def ema_context(): |
82 | if ema_embeddings is not None: | 82 | if ema_embeddings is not None: |
83 | return ema_embeddings.apply_temporary( | 83 | return ema_embeddings.apply_temporary( |
84 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 84 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters() |
85 | ) | 85 | ) |
86 | else: | 86 | else: |
87 | return nullcontext() | 87 | return nullcontext() |
88 | 88 | ||
89 | def on_accum_model(): | 89 | def on_accum_model(): |
90 | return text_encoder.text_model.embeddings.temp_token_embedding | 90 | return text_encoder.text_model.embeddings.token_override_embedding.params |
91 | 91 | ||
92 | @contextmanager | 92 | @contextmanager |
93 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
104 | @torch.no_grad() | 104 | @torch.no_grad() |
105 | def on_after_optimize(zero_ids, lr: float): | 105 | def on_after_optimize(zero_ids, lr: float): |
106 | if ema_embeddings is not None: | 106 | if ema_embeddings is not None: |
107 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 107 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) |
108 | 108 | ||
109 | def on_log(): | 109 | def on_log(): |
110 | if ema_embeddings is not None: | 110 | if ema_embeddings is not None: |