diff options
| -rw-r--r-- | models/clip/embeddings.py | 40 | ||||
| -rw-r--r-- | training/strategy/ti.py | 18 |
2 files changed, 31 insertions, 27 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 870ee49..95904cf 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 42 | self.init_temp_embeddings() | 42 | self.init_temp_embeddings() |
| 43 | 43 | ||
| 44 | def init_temp_embeddings(self): | 44 | def init_temp_embeddings(self): |
| 45 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.ParameterList() |
| 46 | 0, | ||
| 47 | self.token_embedding.embedding_dim, | ||
| 48 | device=self.token_embedding.weight.device, | ||
| 49 | dtype=self.token_embedding.weight.dtype | ||
| 50 | ) | ||
| 51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 46 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 52 | 47 | ||
| 53 | def resize(self, size: int): | 48 | def resize(self, size: int): |
| 54 | self.temp_token_embedding = resize_embedding( | 49 | for _ in range(len(self.temp_token_embedding), size): |
| 55 | self.temp_token_embedding, | 50 | self.temp_token_embedding.append(torch.zeros( |
| 56 | size - self.num_permanent_embeddings, | 51 | self.token_embedding.embedding_dim, |
| 57 | self.initializer_factor | 52 | device=self.token_embedding.weight.device, |
| 58 | ) | 53 | dtype=self.token_embedding.weight.dtype, |
| 54 | )) | ||
| 59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 55 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 60 | 56 | ||
| 61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 57 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
| @@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 74 | with torch.no_grad(): | 70 | with torch.no_grad(): |
| 75 | initializer = self.get_embed(initializer) | 71 | initializer = self.get_embed(initializer) |
| 76 | 72 | ||
| 73 | initializer = initializer.to( | ||
| 74 | device=self.token_embedding.weight.device, | ||
| 75 | dtype=self.token_embedding.weight.dtype, | ||
| 76 | ) | ||
| 77 | |||
| 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 78 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 78 | 79 | ||
| 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 80 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 80 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | 81 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) |
| 81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | 82 | |
| 82 | device=self.temp_token_embedding.weight.device, | 83 | for i, id in enumerate(mask): |
| 83 | dtype=self.temp_token_embedding.weight.dtype, | 84 | self.temp_token_embedding[id] = initializer[i] |
| 84 | ) | ||
| 85 | 85 | ||
| 86 | def load_embed(self, input_ids: list[int], filename: Path): | 86 | def load_embed(self, input_ids: list[int], filename: Path): |
| 87 | with safe_open(filename, framework="pt", device="cpu") as file: | 87 | with safe_open(filename, framework="pt", device="cpu") as file: |
| @@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 91 | save_file({"embed": self.get_embed(input_ids)}, filename) | 91 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 92 | 92 | ||
| 93 | def persist(self): | 93 | def persist(self): |
| 94 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 94 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): |
| 95 | self.token_embedding.weight.data[id] = emb | ||
| 95 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | 96 | self.num_permanent_embeddings = self.token_embedding.num_embeddings |
| 96 | self.init_temp_embeddings() | 97 | self.init_temp_embeddings() |
| 97 | 98 | ||
| @@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 110 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | 111 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) |
| 111 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | 112 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() |
| 112 | 113 | ||
| 113 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | 114 | if len(temp_token_ids): |
| 115 | embeds_override = torch.stack([ | ||
| 116 | self.temp_token_embedding[id] | ||
| 117 | for id in temp_token_ids | ||
| 118 | ]) | ||
| 119 | embeds[embeds_mask] = embeds_override | ||
| 114 | 120 | ||
| 115 | return embeds | 121 | return embeds |
| 116 | 122 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index b9a5547..7ac5011 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks( | |||
| 108 | @torch.no_grad() | 108 | @torch.no_grad() |
| 109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(lr: float, epoch: int): |
| 110 | if use_emb_decay: | 110 | if use_emb_decay: |
| 111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 111 | return torch.stack([ |
| 112 | return torch.all(w.grad == 0, dim=1) | 112 | t |
| 113 | for t in text_encoder.text_model.embeddings.temp_token_embedding | ||
| 114 | if t.grad is not None | ||
| 115 | ]) | ||
| 113 | 116 | ||
| 114 | @torch.no_grad() | 117 | @torch.no_grad() |
| 115 | def on_after_optimize(zero_ids, lr: float): | 118 | def on_after_optimize(w, lr: float): |
| 116 | if ema_embeddings is not None: | 119 | if ema_embeddings is not None: |
| 117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 120 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
| 118 | 121 | ||
| @@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks( | |||
| 120 | lambda_ = emb_decay * lr | 123 | lambda_ = emb_decay * lr |
| 121 | 124 | ||
| 122 | if lambda_ != 0: | 125 | if lambda_ != 0: |
| 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 126 | norm = w[:, :].norm(dim=-1, keepdim=True) |
| 124 | 127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | |
| 125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
| 126 | mask[zero_ids] = False | ||
| 127 | |||
| 128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
| 129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
| 130 | 128 | ||
| 131 | def on_log(): | 129 | def on_log(): |
| 132 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: |
