diff options
| -rw-r--r-- | models/clip/embeddings.py | 33 | ||||
| -rw-r--r-- | train_dreambooth.py | 7 | ||||
| -rw-r--r-- | train_lora.py | 7 | ||||
| -rw-r--r-- | train_ti.py | 7 | ||||
| -rw-r--r-- | training/optimization.py | 2 |
5 files changed, 32 insertions, 24 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index e8cc865..4166dc6 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 38 | self.token_embedding = embeddings.token_embedding | 38 | self.token_embedding = embeddings.token_embedding |
| 39 | self.position_embedding = embeddings.position_embedding | 39 | self.position_embedding = embeddings.position_embedding |
| 40 | self.initializer_factor = config.initializer_factor | 40 | self.initializer_factor = config.initializer_factor |
| 41 | self.init_temp_embeddings() | ||
| 42 | 41 | ||
| 43 | def init_temp_embeddings(self): | ||
| 44 | self.temp_token_embedding = nn.Embedding( | 42 | self.temp_token_embedding = nn.Embedding( |
| 45 | 0, | 43 | self.token_embedding.num_embeddings, |
| 46 | self.token_embedding.embedding_dim, | 44 | self.token_embedding.embedding_dim, |
| 47 | device=self.token_embedding.weight.device, | 45 | device=self.token_embedding.weight.device, |
| 48 | dtype=self.token_embedding.weight.dtype | 46 | dtype=self.token_embedding.weight.dtype |
| 49 | ) | 47 | ) |
| 48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
| 50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 51 | 50 | ||
| 52 | def resize(self, size: int): | 51 | def resize(self, size: int): |
| 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
| 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 54 | 54 | ||
| 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
| @@ -74,16 +74,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 74 | ) | 74 | ) |
| 75 | 75 | ||
| 76 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 76 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 77 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | ||
| 78 | |||
| 79 | self.temp_token_embedding = resize_embedding( | ||
| 80 | self.temp_token_embedding, | ||
| 81 | self.temp_token_ids.shape[0], | ||
| 82 | self.initializer_factor | ||
| 83 | ) | ||
| 84 | 77 | ||
| 85 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | 78 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 86 | self.temp_token_embedding.weight.data[mask] = initializer | 79 | self.temp_token_embedding.weight.data[token_ids] = initializer |
| 87 | self.token_embedding.weight.data[token_ids] = initializer | 80 | self.token_embedding.weight.data[token_ids] = initializer |
| 88 | 81 | ||
| 89 | def load_embed(self, input_ids: list[int], filename: Path): | 82 | def load_embed(self, input_ids: list[int], filename: Path): |
| @@ -94,25 +87,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 94 | save_file({"embed": self.get_embed(input_ids)}, filename) | 87 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 95 | 88 | ||
| 96 | def persist(self): | 89 | def persist(self): |
| 97 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:] | 90 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
| 98 | self.init_temp_embeddings() | 91 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 99 | 92 | ||
| 100 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 93 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 101 | if isinstance(input_ids, list): | 94 | if isinstance(input_ids, list): |
| 102 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 95 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
| 103 | 96 | ||
| 104 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
| 105 | |||
| 106 | embeds = self.token_embedding(input_ids) | 97 | embeds = self.token_embedding(input_ids) |
| 107 | 98 | ||
| 108 | embeds_mask = torch.isin(input_ids, all_temp_token_ids) | 99 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
| 109 | temp_token_ids = input_ids[embeds_mask] | 100 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] |
| 110 | |||
| 111 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
| 112 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
| 113 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
| 114 | |||
| 115 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | ||
| 116 | 101 | ||
| 117 | return embeds | 102 | return embeds |
| 118 | 103 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0634376..2c884d2 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -240,6 +240,12 @@ def parse_args(): | |||
| 240 | help="Number of steps for the warmup in the lr scheduler." | 240 | help="Number of steps for the warmup in the lr scheduler." |
| 241 | ) | 241 | ) |
| 242 | parser.add_argument( | 242 | parser.add_argument( |
| 243 | "--lr_mid_point", | ||
| 244 | type=float, | ||
| 245 | default=0.3, | ||
| 246 | help="OneCycle schedule mid point." | ||
| 247 | ) | ||
| 248 | parser.add_argument( | ||
| 243 | "--lr_cycles", | 249 | "--lr_cycles", |
| 244 | type=int, | 250 | type=int, |
| 245 | default=None, | 251 | default=None, |
| @@ -634,6 +640,7 @@ def main(): | |||
| 634 | end_lr=1e2, | 640 | end_lr=1e2, |
| 635 | train_epochs=num_train_epochs, | 641 | train_epochs=num_train_epochs, |
| 636 | warmup_epochs=args.lr_warmup_epochs, | 642 | warmup_epochs=args.lr_warmup_epochs, |
| 643 | mid_point=args.lr_mid_point, | ||
| 637 | ) | 644 | ) |
| 638 | 645 | ||
| 639 | metrics = trainer( | 646 | metrics = trainer( |
diff --git a/train_lora.py b/train_lora.py index d89b18d..59beb09 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -272,6 +272,12 @@ def parse_args(): | |||
| 272 | help="Number of steps for the warmup in the lr scheduler." | 272 | help="Number of steps for the warmup in the lr scheduler." |
| 273 | ) | 273 | ) |
| 274 | parser.add_argument( | 274 | parser.add_argument( |
| 275 | "--lr_mid_point", | ||
| 276 | type=float, | ||
| 277 | default=0.3, | ||
| 278 | help="OneCycle schedule mid point." | ||
| 279 | ) | ||
| 280 | parser.add_argument( | ||
| 275 | "--lr_cycles", | 281 | "--lr_cycles", |
| 276 | type=int, | 282 | type=int, |
| 277 | default=None, | 283 | default=None, |
| @@ -662,6 +668,7 @@ def main(): | |||
| 662 | end_lr=1e2, | 668 | end_lr=1e2, |
| 663 | train_epochs=num_train_epochs, | 669 | train_epochs=num_train_epochs, |
| 664 | warmup_epochs=args.lr_warmup_epochs, | 670 | warmup_epochs=args.lr_warmup_epochs, |
| 671 | mid_point=args.lr_mid_point, | ||
| 665 | ) | 672 | ) |
| 666 | 673 | ||
| 667 | metrics = trainer( | 674 | metrics = trainer( |
diff --git a/train_ti.py b/train_ti.py index b182a72..83043ad 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -258,6 +258,12 @@ def parse_args(): | |||
| 258 | help="Number of steps for the warmup in the lr scheduler." | 258 | help="Number of steps for the warmup in the lr scheduler." |
| 259 | ) | 259 | ) |
| 260 | parser.add_argument( | 260 | parser.add_argument( |
| 261 | "--lr_mid_point", | ||
| 262 | type=float, | ||
| 263 | default=0.3, | ||
| 264 | help="OneCycle schedule mid point." | ||
| 265 | ) | ||
| 266 | parser.add_argument( | ||
| 261 | "--lr_cycles", | 267 | "--lr_cycles", |
| 262 | type=int, | 268 | type=int, |
| 263 | default=None, | 269 | default=None, |
| @@ -790,6 +796,7 @@ def main(): | |||
| 790 | end_lr=1e3, | 796 | end_lr=1e3, |
| 791 | train_epochs=num_train_epochs, | 797 | train_epochs=num_train_epochs, |
| 792 | warmup_epochs=args.lr_warmup_epochs, | 798 | warmup_epochs=args.lr_warmup_epochs, |
| 799 | mid_point=args.lr_mid_point, | ||
| 793 | ) | 800 | ) |
| 794 | 801 | ||
| 795 | metrics = trainer( | 802 | metrics = trainer( |
diff --git a/training/optimization.py b/training/optimization.py index 7d8d55a..59ca950 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -102,6 +102,7 @@ def get_scheduler( | |||
| 102 | num_training_steps_per_epoch: int, | 102 | num_training_steps_per_epoch: int, |
| 103 | gradient_accumulation_steps: int = 1, | 103 | gradient_accumulation_steps: int = 1, |
| 104 | min_lr: float = 0.04, | 104 | min_lr: float = 0.04, |
| 105 | mid_point: float = 0.3, | ||
| 105 | warmup_func: Literal["cos", "linear"] = "cos", | 106 | warmup_func: Literal["cos", "linear"] = "cos", |
| 106 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", | 107 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
| 107 | warmup_exp: int = 1, | 108 | warmup_exp: int = 1, |
| @@ -126,6 +127,7 @@ def get_scheduler( | |||
| 126 | warmup_exp=warmup_exp, | 127 | warmup_exp=warmup_exp, |
| 127 | annealing_exp=annealing_exp, | 128 | annealing_exp=annealing_exp, |
| 128 | min_lr=min_lr, | 129 | min_lr=min_lr, |
| 130 | mid_point=mid_point, | ||
| 129 | ) | 131 | ) |
| 130 | elif id == "exponential_growth": | 132 | elif id == "exponential_growth": |
| 131 | if cycles is None: | 133 | if cycles is None: |
