diff options
-rw-r--r-- | models/clip/embeddings.py | 34 | ||||
-rw-r--r-- | train_ti.py | 16 |
2 files changed, 33 insertions, 17 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 2d60c28..e8cc865 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
38 | self.token_embedding = embeddings.token_embedding | 38 | self.token_embedding = embeddings.token_embedding |
39 | self.position_embedding = embeddings.position_embedding | 39 | self.position_embedding = embeddings.position_embedding |
40 | self.initializer_factor = config.initializer_factor | 40 | self.initializer_factor = config.initializer_factor |
41 | self.init_temp_embeddings() | ||
41 | 42 | ||
43 | def init_temp_embeddings(self): | ||
42 | self.temp_token_embedding = nn.Embedding( | 44 | self.temp_token_embedding = nn.Embedding( |
43 | self.token_embedding.num_embeddings, | 45 | 0, |
44 | self.token_embedding.embedding_dim, | 46 | self.token_embedding.embedding_dim, |
45 | device=self.token_embedding.weight.device, | 47 | device=self.token_embedding.weight.device, |
46 | dtype=self.token_embedding.weight.dtype | 48 | dtype=self.token_embedding.weight.dtype |
47 | ) | 49 | ) |
48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
50 | 51 | ||
51 | def resize(self, size: int): | 52 | def resize(self, size: int): |
52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 54 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
@@ -74,9 +74,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
74 | ) | 74 | ) |
75 | 75 | ||
76 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 76 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
77 | |||
78 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 77 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
79 | self.temp_token_embedding.weight.data[token_ids] = initializer | 78 | |
79 | self.temp_token_embedding = resize_embedding( | ||
80 | self.temp_token_embedding, | ||
81 | self.temp_token_ids.shape[0], | ||
82 | self.initializer_factor | ||
83 | ) | ||
84 | |||
85 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | ||
86 | self.temp_token_embedding.weight.data[mask] = initializer | ||
87 | self.token_embedding.weight.data[token_ids] = initializer | ||
80 | 88 | ||
81 | def load_embed(self, input_ids: list[int], filename: Path): | 89 | def load_embed(self, input_ids: list[int], filename: Path): |
82 | with safe_open(filename, framework="pt", device="cpu") as file: | 90 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -86,17 +94,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
86 | save_file({"embed": self.get_embed(input_ids)}, filename) | 94 | save_file({"embed": self.get_embed(input_ids)}, filename) |
87 | 95 | ||
88 | def persist(self): | 96 | def persist(self): |
89 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 97 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:] |
90 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 98 | self.init_temp_embeddings() |
91 | 99 | ||
92 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 100 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
93 | if isinstance(input_ids, list): | 101 | if isinstance(input_ids, list): |
94 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 102 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
95 | 103 | ||
104 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
105 | |||
96 | embeds = self.token_embedding(input_ids) | 106 | embeds = self.token_embedding(input_ids) |
97 | 107 | ||
98 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 108 | embeds_mask = torch.isin(input_ids, all_temp_token_ids) |
99 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 109 | temp_token_ids = input_ids[embeds_mask] |
110 | |||
111 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
112 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
113 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
114 | |||
115 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | ||
100 | 116 | ||
101 | return embeds | 117 | return embeds |
102 | 118 | ||
diff --git a/train_ti.py b/train_ti.py index 9ae8d1b..e4fd464 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -588,14 +588,6 @@ def main(): | |||
588 | unet.enable_gradient_checkpointing() | 588 | unet.enable_gradient_checkpointing() |
589 | text_encoder.gradient_checkpointing_enable() | 589 | text_encoder.gradient_checkpointing_enable() |
590 | 590 | ||
591 | if args.embeddings_dir is not None: | ||
592 | embeddings_dir = Path(args.embeddings_dir) | ||
593 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
594 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
595 | |||
596 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
597 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
598 | |||
599 | if len(args.alias_tokens) != 0: | 591 | if len(args.alias_tokens) != 0: |
600 | alias_placeholder_tokens = args.alias_tokens[::2] | 592 | alias_placeholder_tokens = args.alias_tokens[::2] |
601 | alias_initializer_tokens = args.alias_tokens[1::2] | 593 | alias_initializer_tokens = args.alias_tokens[1::2] |
@@ -609,6 +601,14 @@ def main(): | |||
609 | embeddings.persist() | 601 | embeddings.persist() |
610 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 602 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") |
611 | 603 | ||
604 | if args.embeddings_dir is not None: | ||
605 | embeddings_dir = Path(args.embeddings_dir) | ||
606 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
607 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
608 | |||
609 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
610 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
611 | |||
612 | if args.scale_lr: | 612 | if args.scale_lr: |
613 | args.learning_rate = ( | 613 | args.learning_rate = ( |
614 | args.learning_rate * args.gradient_accumulation_steps * | 614 | args.learning_rate * args.gradient_accumulation_steps * |