-rw-r--r--  models/clip/embeddings.py  33
-rw-r--r--  train_dreambooth.py         7
-rw-r--r--  train_lora.py               7
-rw-r--r--  train_ti.py                 7
-rw-r--r--  training/optimization.py    2
5 files changed, 32 insertions, 24 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index e8cc865..4166dc6 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.init_temp_embeddings()

-    def init_temp_embeddings(self):
         self.temp_token_embedding = nn.Embedding(
-            0,
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)

     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)

     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,16 +74,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         )

         token_ids = torch.tensor(token_ids, dtype=torch.long)
-        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        self.temp_token_embedding = resize_embedding(
-            self.temp_token_embedding,
-            self.temp_token_ids.shape[0],
-            self.initializer_factor
-        )

-        mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
-        self.temp_token_embedding.weight.data[mask] = initializer
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer

     def load_embed(self, input_ids: list[int], filename: Path):
@@ -94,25 +87,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)

     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:]
-        self.init_temp_embeddings()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)

     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)

-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
         embeds = self.token_embedding(input_ids)

-        embeds_mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[embeds_mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]

         return embeds

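The change above keeps temp_token_embedding the same size as token_embedding, so get_embed can index it with the original token ids and overlay only the rows listed in temp_token_ids, instead of remapping ids into a smaller table. A minimal, self-contained sketch of that overlay-lookup pattern; the names here (OverlayEmbedding, add) are illustrative stand-ins, not the repository's API:

import torch
import torch.nn as nn

# Sketch of the overlay-lookup pattern: a full-size "temp" table shadows the
# base table, and only rows registered in temp_token_ids take precedence.
class OverlayEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.base = nn.Embedding(num_embeddings, embedding_dim)
        self.temp = nn.Embedding(num_embeddings, embedding_dim)
        self.temp.weight.data = self.base.weight.data.clone().detach()
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def add(self, token_ids: torch.LongTensor, init: torch.Tensor):
        # Register new ids and write their initial vectors into both tables.
        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
        self.temp.weight.data[token_ids] = init
        self.base.weight.data[token_ids] = init

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        embeds = self.base(input_ids)
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
        # Index with the original ids: temp is full-size, so no remapping is needed.
        embeds[mask] = self.temp(input_ids)[mask]
        return embeds

overlay = OverlayEmbedding(10, 4)
overlay.add(torch.tensor([3]), torch.zeros(1, 4))
print(overlay(torch.tensor([[1, 3, 5]])).shape)  # torch.Size([1, 3, 4])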
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0634376..2c884d2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -240,6 +240,12 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
+    parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
@@ -634,6 +640,7 @@ def main():
         end_lr=1e2,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )

     metrics = trainer(
diff --git a/train_lora.py b/train_lora.py
index d89b18d..59beb09 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -272,6 +272,12 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
+    parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
@@ -662,6 +668,7 @@ def main():
         end_lr=1e2,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )

     metrics = trainer(
diff --git a/train_ti.py b/train_ti.py
index b182a72..83043ad 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -258,6 +258,12 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
+    parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
@@ -790,6 +796,7 @@ def main():
         end_lr=1e3,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )

     metrics = trainer(
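The same --lr_mid_point option is added to train_dreambooth.py, train_lora.py, and train_ti.py and forwarded to the scheduler setup as mid_point=args.lr_mid_point. A minimal sketch of that plumbing pattern; the parse_args call with an explicit argument list below is only for illustration:

import argparse

# Sketch of the new flag as it appears in the three training scripts.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lr_mid_point",
    type=float,
    default=0.3,
    help="OneCycle schedule mid point."
)
args = parser.parse_args(["--lr_mid_point", "0.25"])

# Downstream, the value is passed through as mid_point=args.lr_mid_point,
# next to the existing warmup/annealing arguments (see the hunks above).
print(args.lr_mid_point)  # 0.25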
diff --git a/training/optimization.py b/training/optimization.py
index 7d8d55a..59ca950 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -102,6 +102,7 @@ def get_scheduler(
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int = 1,
     min_lr: float = 0.04,
+    mid_point: float = 0.3,
     warmup_func: Literal["cos", "linear"] = "cos",
     annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
@@ -126,6 +127,7 @@ def get_scheduler(
             warmup_exp=warmup_exp,
             annealing_exp=annealing_exp,
             min_lr=min_lr,
+            mid_point=mid_point,
         )
     elif id == "exponential_growth":
         if cycles is None:
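The one-cycle schedule implementation itself is not part of this diff, so the following is only an assumed illustration of what a mid_point fraction typically controls: the share of total steps spent warming up before annealing toward min_lr. The function below is a hypothetical stand-in, not the repository's get_one_cycle_schedule:

import math

# Assumption: mid_point is the fraction of total steps at which the schedule
# peaks; before that point the factor ramps up linearly, after it the factor
# anneals with a cosine down to min_lr.
def one_cycle_factor(step: int, total_steps: int, mid_point: float = 0.3, min_lr: float = 0.04) -> float:
    peak = int(total_steps * mid_point)
    if step < peak:
        return step / max(peak, 1)  # linear warmup to 1.0
    progress = (step - peak) / max(total_steps - peak, 1)
    return min_lr + (1 - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# With mid_point=0.3 the peak sits at 30% of training; raising it delays annealing.
print([round(one_cycle_factor(s, 10, mid_point=0.3), 2) for s in range(10)])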