summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py40
-rw-r--r--training/strategy/ti.py18
2 files changed, 31 insertions, 27 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 870ee49..95904cf 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
42 self.init_temp_embeddings() 42 self.init_temp_embeddings()
43 43
44 def init_temp_embeddings(self): 44 def init_temp_embeddings(self):
45 self.temp_token_embedding = nn.Embedding( 45 self.temp_token_embedding = nn.ParameterList()
46 0,
47 self.token_embedding.embedding_dim,
48 device=self.token_embedding.weight.device,
49 dtype=self.token_embedding.weight.dtype
50 )
51 self.temp_token_ids = torch.tensor([], dtype=torch.long) 46 self.temp_token_ids = torch.tensor([], dtype=torch.long)
52 47
53 def resize(self, size: int): 48 def resize(self, size: int):
54 self.temp_token_embedding = resize_embedding( 49 for _ in range(len(self.temp_token_embedding), size):
55 self.temp_token_embedding, 50 self.temp_token_embedding.append(torch.zeros(
56 size - self.num_permanent_embeddings, 51 self.token_embedding.embedding_dim,
57 self.initializer_factor 52 device=self.token_embedding.weight.device,
58 ) 53 dtype=self.token_embedding.weight.dtype,
54 ))
59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 55 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
60 56
61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 57 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 with torch.no_grad(): 70 with torch.no_grad():
75 initializer = self.get_embed(initializer) 71 initializer = self.get_embed(initializer)
76 72
73 initializer = initializer.to(
74 device=self.token_embedding.weight.device,
75 dtype=self.token_embedding.weight.dtype,
76 )
77
77 token_ids = torch.tensor(token_ids, dtype=torch.long) 78 token_ids = torch.tensor(token_ids, dtype=torch.long)
78 79
79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 80 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
80 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) 81 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
81 self.temp_token_embedding.weight.data[mask] = initializer.to( 82
82 device=self.temp_token_embedding.weight.device, 83 for i, id in enumerate(mask):
83 dtype=self.temp_token_embedding.weight.dtype, 84 self.temp_token_embedding[id] = initializer[i]
84 )
85 85
86 def load_embed(self, input_ids: list[int], filename: Path): 86 def load_embed(self, input_ids: list[int], filename: Path):
87 with safe_open(filename, framework="pt", device="cpu") as file: 87 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
91 save_file({"embed": self.get_embed(input_ids)}, filename) 91 save_file({"embed": self.get_embed(input_ids)}, filename)
92 92
93 def persist(self): 93 def persist(self):
94 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 94 for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
95 self.token_embedding.weight.data[id] = emb
95 self.num_permanent_embeddings = self.token_embedding.num_embeddings 96 self.num_permanent_embeddings = self.token_embedding.num_embeddings
96 self.init_temp_embeddings() 97 self.init_temp_embeddings()
97 98
@@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
110 all_temp_token_ids = all_temp_token_ids.unsqueeze(0) 111 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
111 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() 112 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
112 113
113 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) 114 if len(temp_token_ids):
115 embeds_override = torch.stack([
116 self.temp_token_embedding[id]
117 for id in temp_token_ids
118 ])
119 embeds[embeds_mask] = embeds_override
114 120
115 return embeds 121 return embeds
116 122
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..7ac5011 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks(
108 @torch.no_grad() 108 @torch.no_grad()
109 def on_before_optimize(lr: float, epoch: int): 109 def on_before_optimize(lr: float, epoch: int):
110 if use_emb_decay: 110 if use_emb_decay:
111 w = text_encoder.text_model.embeddings.temp_token_embedding.weight 111 return torch.stack([
112 return torch.all(w.grad == 0, dim=1) 112 t
113 for t in text_encoder.text_model.embeddings.temp_token_embedding
114 if t.grad is not None
115 ])
113 116
114 @torch.no_grad() 117 @torch.no_grad()
115 def on_after_optimize(zero_ids, lr: float): 118 def on_after_optimize(w, lr: float):
116 if ema_embeddings is not None: 119 if ema_embeddings is not None:
117 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) 120 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
118 121
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks(
120 lambda_ = emb_decay * lr 123 lambda_ = emb_decay * lr
121 124
122 if lambda_ != 0: 125 if lambda_ != 0:
123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight 126 norm = w[:, :].norm(dim=-1, keepdim=True)
124 127 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
125 mask = torch.ones(w.shape[0], dtype=torch.bool)
126 mask[zero_ids] = False
127
128 norm = w[mask, :].norm(dim=-1, keepdim=True)
129 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
130 128
131 def on_log(): 129 def on_log():
132 if ema_embeddings is not None: 130 if ema_embeddings is not None: