| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-03-27 07:15:46 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2023-03-27 07:15:46 +0200 |
| commit | 0e4c36889aa6b7ec13320a03728118c7c1a8e716 (patch) | |
| tree | 461e63354dac6ab5b68d0f57e1569798df5bf202 | |
| parent | Fix TI embeddings init (diff) | |
| download | textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.gz, textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.bz2, textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.zip | |
Sparse TI embeddings without sparse tensors
-rw-r--r-- | models/clip/embeddings.py | 40
-rw-r--r-- | training/strategy/ti.py | 18

2 files changed, 31 insertions, 27 deletions
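The change swaps the dense `nn.Embedding` that previously held the temporary (trainable) token vectors for an `nn.ParameterList` with one vector per newly added token, so only those few vectors exist as trainable parameters and no sparse tensors are needed. The sketch below is a minimal illustration of that pattern, not the actual code in `models/clip/embeddings.py`; the class and method names (`SparseTempEmbeddings`, `add_token`, `persist`) are invented for the example.

```python
import torch
import torch.nn as nn


class SparseTempEmbeddings(nn.Module):
    """Hypothetical sketch: a dense base table plus per-token trainable vectors."""

    def __init__(self, base: nn.Embedding):
        super().__init__()
        self.base = base                        # permanent embedding table
        self.temp = nn.ParameterList()          # one trainable vector per new token
        self.temp_ids = torch.tensor([], dtype=torch.long)

    def add_token(self, token_id: int, init: torch.Tensor):
        # Each new token gets its own nn.Parameter instead of a row in a
        # second dense table, so only these vectors receive gradients.
        self.temp.append(nn.Parameter(init.detach().clone()))
        self.temp_ids = torch.cat([self.temp_ids, torch.tensor([token_id])])

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeds = self.base(input_ids)
        mask = torch.isin(input_ids, self.temp_ids)
        if mask.any():
            # Map each masked position to its index in the ParameterList,
            # then overwrite those rows with the trainable vectors.
            idx = torch.nonzero(input_ids[mask].unsqueeze(-1) == self.temp_ids)[:, 1]
            embeds[mask] = torch.stack([self.temp[int(i)] for i in idx])
        return embeds

    @torch.no_grad()
    def persist(self):
        # Bake the trained vectors back into the permanent table and reset.
        for token_id, emb in zip(self.temp_ids, self.temp):
            self.base.weight.data[token_id] = emb
        self.temp = nn.ParameterList()
        self.temp_ids = torch.tensor([], dtype=torch.long)
```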
```diff
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 870ee49..95904cf 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.init_temp_embeddings()
 
     def init_temp_embeddings(self):
-        self.temp_token_embedding = nn.Embedding(
-            0,
-            self.token_embedding.embedding_dim,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype
-        )
+        self.temp_token_embedding = nn.ParameterList()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(
-            self.temp_token_embedding,
-            size - self.num_permanent_embeddings,
-            self.initializer_factor
-        )
+        for _ in range(len(self.temp_token_embedding), size):
+            self.temp_token_embedding.append(torch.zeros(
+                self.token_embedding.embedding_dim,
+                device=self.token_embedding.weight.device,
+                dtype=self.token_embedding.weight.dtype,
+            ))
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         with torch.no_grad():
             initializer = self.get_embed(initializer)
 
+        initializer = initializer.to(
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
+        )
+
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
-        self.temp_token_embedding.weight.data[mask] = initializer.to(
-            device=self.temp_token_embedding.weight.device,
-            dtype=self.temp_token_embedding.weight.dtype,
-        )
+
+        for i, id in enumerate(mask):
+            self.temp_token_embedding[id] = initializer[i]
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
+            self.token_embedding.weight.data[id] = emb
         self.num_permanent_embeddings = self.token_embedding.num_embeddings
         self.init_temp_embeddings()
 
@@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
         temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
 
-        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
+        if len(temp_token_ids):
+            embeds_override = torch.stack([
+                self.temp_token_embedding[id]
+                for id in temp_token_ids
+            ])
+            embeds[embeds_mask] = embeds_override
 
         return embeds
 
```
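Continuing the hypothetical `SparseTempEmbeddings` sketch above, a quick end-to-end pass might look like the following; the vocabulary size matches CLIP's 49408-entry text embedding table, but the token ids are arbitrary placeholders.

```python
# Usage of the hypothetical SparseTempEmbeddings sketch (illustration only).
base = nn.Embedding(49408, 768)
emb = SparseTempEmbeddings(base)

new_id = 49407                              # pretend this is a freshly added placeholder token
emb.add_token(new_id, base.weight[320])     # initialize it from an existing token's vector

input_ids = torch.tensor([[320, new_id, 0]])
out = emb(input_ids)                        # the new_id row comes from the ParameterList
print(out.shape)                            # torch.Size([1, 3, 768])

emb.persist()                               # copy the trained vector into the dense table
```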
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..7ac5011 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
+            return torch.stack([
+                t
+                for t in text_encoder.text_model.embeddings.temp_token_embedding
+                if t.grad is not None
+            ])
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None:
```
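For reference, the decay applied in `on_after_optimize` nudges each embedding vector's L2 norm toward `emb_decay_target`. Below is a standalone rendering of that update on a stand-in tensor; `emb_decay`, `emb_decay_target`, and `lr` mirror the names in `training/strategy/ti.py`, but the values and the tensor `w` are arbitrary and not tied to the text encoder.

```python
import torch

emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-3
lambda_ = emb_decay * lr

w = torch.randn(3, 768)                      # three stand-in embedding vectors
norm = w.norm(dim=-1, keepdim=True)          # per-vector L2 norm, shape (3, 1)
# Pull each vector's norm toward emb_decay_target, scaled by lambda_.
w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
```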