summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py34
-rw-r--r--train_ti.py11
2 files changed, 19 insertions, 26 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b315c4..2d60c28 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,24 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
38 self.token_embedding = embeddings.token_embedding 38 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 39 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 40 self.initializer_factor = config.initializer_factor
41 self.num_permanent_embeddings = self.token_embedding.num_embeddings
42 self.init_temp_embeddings()
43 41
44 def init_temp_embeddings(self):
45 self.temp_token_embedding = nn.Embedding( 42 self.temp_token_embedding = nn.Embedding(
46 0, 43 self.token_embedding.num_embeddings,
47 self.token_embedding.embedding_dim, 44 self.token_embedding.embedding_dim,
48 device=self.token_embedding.weight.device, 45 device=self.token_embedding.weight.device,
49 dtype=self.token_embedding.weight.dtype 46 dtype=self.token_embedding.weight.dtype
50 ) 47 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
51 self.temp_token_ids = torch.tensor([], dtype=torch.long) 49 self.temp_token_ids = torch.tensor([], dtype=torch.long)
52 50
53 def resize(self, size: int): 51 def resize(self, size: int):
54 self.temp_token_embedding = resize_embedding( 52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
55 self.temp_token_embedding,
56 size - self.num_permanent_embeddings,
57 self.initializer_factor
58 )
59 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
60 54
61 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -75,15 +69,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
75 initializer = self.get_embed(initializer) 69 initializer = self.get_embed(initializer)
76 70
77 initializer = initializer.to( 71 initializer = initializer.to(
78 device=self.token_embedding.weight.device, 72 device=self.temp_token_embedding.weight.device,
79 dtype=self.token_embedding.weight.dtype, 73 dtype=self.temp_token_embedding.weight.dtype,
80 ) 74 )
81 75
82 token_ids = torch.tensor(token_ids, dtype=torch.long) 76 token_ids = torch.tensor(token_ids, dtype=torch.long)
83 77
84 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 78 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
85 mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) 79 self.temp_token_embedding.weight.data[token_ids] = initializer
86 self.temp_token_embedding.weight.data[mask] = initializer
87 80
88 def load_embed(self, input_ids: list[int], filename: Path): 81 def load_embed(self, input_ids: list[int], filename: Path):
89 with safe_open(filename, framework="pt", device="cpu") as file: 82 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -94,25 +87,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
94 87
95 def persist(self): 88 def persist(self):
96 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 89 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
97 self.num_permanent_embeddings = self.token_embedding.num_embeddings 90 self.temp_token_ids = torch.tensor([], dtype=torch.long)
98 self.init_temp_embeddings()
99 91
100 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 92 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
101 if isinstance(input_ids, list): 93 if isinstance(input_ids, list):
102 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 94 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
103 95
104 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
105
106 embeds = self.token_embedding(input_ids) 96 embeds = self.token_embedding(input_ids)
107 97
108 embeds_mask = torch.isin(input_ids, all_temp_token_ids) 98 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
109 temp_token_ids = input_ids[embeds_mask] 99 embeds[mask] = self.temp_token_embedding(input_ids)[mask]
110
111 temp_token_ids = temp_token_ids.unsqueeze(1)
112 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
113 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
114
115 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
116 100
117 return embeds 101 return embeds
118 102
diff --git a/train_ti.py b/train_ti.py
index ef39c38..9ae8d1b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,7 +155,7 @@ def parse_args():
155 parser.add_argument( 155 parser.add_argument(
156 "--num_buckets", 156 "--num_buckets",
157 type=int, 157 type=int,
158 default=2, 158 default=0,
159 help="Number of aspect ratio buckets in either direction.", 159 help="Number of aspect ratio buckets in either direction.",
160 ) 160 )
161 parser.add_argument( 161 parser.add_argument(
@@ -507,9 +507,18 @@ def parse_args():
507 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): 507 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
508 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 508 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
509 509
510 if args.alias_tokens is None:
511 args.alias_tokens = []
512
510 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: 513 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
511 raise ValueError("--alias_tokens must be a list with an even number of items") 514 raise ValueError("--alias_tokens must be a list with an even number of items")
512 515
516 args.alias_tokens += [
517 item
518 for pair in zip(args.placeholder_tokens, args.initializer_tokens)
519 for item in pair
520 ]
521
513 if args.sequential: 522 if args.sequential:
514 if isinstance(args.train_data_template, str): 523 if isinstance(args.train_data_template, str):
515 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) 524 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)