-rw-r--r--  models/clip/embeddings.py | 34
-rw-r--r--  train_ti.py               | 16
2 files changed, 33 insertions, 17 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2d60c28..e8cc865 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.init_temp_embeddings()
 
+    def init_temp_embeddings(self):
         self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings,
+            0,
             self.token_embedding.embedding_dim,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
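Note (sketch, not part of the commit): with this hunk the temporary embedding starts with zero rows and is grown on demand instead of mirroring the whole vocabulary. A minimal, self-contained illustration of that pattern; the resize_embedding helper below is a hypothetical stand-in for the repo's own helper of the same name:

    import torch
    import torch.nn as nn

    def resize_embedding(old: nn.Embedding, new_size: int, initializer_factor: float = 1.0) -> nn.Embedding:
        # Hypothetical stand-in for the repo's resize_embedding: keep existing rows,
        # initialize any new rows from a small normal distribution.
        new = nn.Embedding(new_size, old.embedding_dim,
                           device=old.weight.device, dtype=old.weight.dtype)
        new.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        n = min(old.num_embeddings, new_size)
        new.weight.data[:n] = old.weight.data[:n]
        return new

    temp = nn.Embedding(0, 768)                        # zero rows, as in init_temp_embeddings
    temp_ids = torch.tensor([], dtype=torch.long)      # no temporary tokens tracked yet

    new_ids = torch.tensor([49408, 49409])             # e.g. two freshly added placeholder tokens
    temp_ids = torch.cat([temp_ids, new_ids])
    temp = resize_embedding(temp, temp_ids.shape[0])   # grown to exactly one row per tracked id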
@@ -74,9 +74,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         )
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
-
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        self.temp_token_embedding.weight.data[token_ids] = initializer
+
+        self.temp_token_embedding = resize_embedding(
+            self.temp_token_embedding,
+            self.temp_token_ids.shape[0],
+            self.initializer_factor
+        )
+
+        mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
+        self.temp_token_embedding.weight.data[mask] = initializer
+        self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
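Note (sketch, not part of the commit): after the resize above, rows of the temporary table are addressed by their position in temp_token_ids rather than by token id, hence the torch.isin/torch.nonzero mask. A small runnable walk-through of that indexing with made-up ids:

    import torch
    import torch.nn as nn

    temp_token_ids = torch.tensor([49408, 49409, 49410])   # ids tracked so far, newest appended last
    token_ids = torch.tensor([49409, 49410])                # the ids added in this add_embed call
    temp = nn.Embedding(temp_token_ids.shape[0], 4)         # one row per tracked id

    # positions (not token ids) of the rows that belong to the newly added tokens
    mask = torch.nonzero(torch.isin(temp_token_ids, token_ids)).squeeze(1)
    print(mask)                                             # tensor([1, 2])

    initializer = torch.zeros(len(token_ids), 4)            # stand-in initializer rows
    temp.weight.data[mask] = initializer                    # overwrite exactly those rows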
@@ -86,17 +94,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:]
+        self.init_temp_embeddings()
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
+        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
+
         embeds = self.token_embedding(input_ids)
 
-        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+        embeds_mask = torch.isin(input_ids, all_temp_token_ids)
+        temp_token_ids = input_ids[embeds_mask]
+
+        temp_token_ids = temp_token_ids.unsqueeze(1)
+        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
+        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
+
+        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
 
         return embeds
 
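Note (sketch, not part of the commit): get_embed now has to translate global token ids into row indices of the compact temporary table before the lookup. A minimal reproduction of that translation with made-up ids:

    import torch

    all_temp_token_ids = torch.tensor([49408, 49409, 49410])   # one temp-table row per id
    input_ids = torch.tensor([320, 49409, 1125, 49408])

    embeds_mask = torch.isin(input_ids, all_temp_token_ids)    # positions served by the temp table
    temp_token_ids = input_ids[embeds_mask]                     # tensor([49409, 49408])

    # broadcast comparison: map each selected id to its row in all_temp_token_ids
    rows = torch.nonzero(
        temp_token_ids.unsqueeze(1) == all_temp_token_ids.unsqueeze(0)
    )[:, 1]
    print(rows)                                                 # tensor([1, 0])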
diff --git a/train_ti.py b/train_ti.py
index 9ae8d1b..e4fd464 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -588,14 +588,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    if args.embeddings_dir is not None:
-        embeddings_dir = Path(args.embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
-
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
-
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
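Note (sketch, not part of the commit): the surrounding context splits --alias_tokens, a flat list that alternates placeholder and initializer tokens, with step-2 slicing. For example:

    # Illustration with made-up tokens of the slicing used in the context above.
    alias_tokens = ["<dog>", "dog", "<cat>", "cat"]
    alias_placeholder_tokens = alias_tokens[::2]    # ['<dog>', '<cat>']
    alias_initializer_tokens = alias_tokens[1::2]   # ['dog', 'cat']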
@@ -609,6 +601,14 @@ def main():
         embeddings.persist()
         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
 
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
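Note (sketch, not part of the commit): the --embeddings_dir block itself is unchanged; it has only moved below the alias-token block, so embeddings loaded from the directory are now applied after the aliases have been added and persisted. A toy example of why ordering matters when both steps write the same embedding rows (all names made up):

    import torch

    weight = torch.zeros(4, 3)                # stand-in embedding table
    alias_row = torch.full((3,), 1.0)         # value written by the alias-token step
    loaded_row = torch.full((3,), 2.0)        # value loaded from --embeddings_dir

    weight[2] = alias_row                     # aliases added and persisted first
    weight[2] = loaded_row                    # directory load runs afterwards
    print(weight[2])                          # tensor([2., 2., 2.]) - the later write wins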