From 5f0f6aac63373780132fced5ad8fd6216097f5ae Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 31 Mar 2023 21:05:11 +0200
Subject: Update

---
 models/clip/embeddings.py | 33 +++++++++------------------------
 train_dreambooth.py       |  7 +++++++
 train_lora.py             |  7 +++++++
 train_ti.py               |  7 +++++++
 training/optimization.py  |  2 ++
 5 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index e8cc865..4166dc6 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.init_temp_embeddings()
 
-    def init_temp_embeddings(self):
         self.temp_token_embedding = nn.Embedding(
-            0,
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,16 +74,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             )
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
-        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        self.temp_token_embedding = resize_embedding(
-            self.temp_token_embedding,
-            self.temp_token_ids.shape[0],
-            self.initializer_factor
-        )
 
-        mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
-        self.temp_token_embedding.weight.data[mask] = initializer
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -94,25 +87,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:]
-        self.init_temp_embeddings()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
         embeds = self.token_embedding(input_ids)
-        embeds_mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[embeds_mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0634376..2c884d2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -239,6 +239,12 @@ def parse_args():
         default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
     parser.add_argument(
         "--lr_cycles",
         type=int,
@@ -634,6 +640,7 @@ def main():
         end_lr=1e2,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )
 
     metrics = trainer(
diff --git a/train_lora.py b/train_lora.py
index d89b18d..59beb09 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -271,6 +271,12 @@ def parse_args():
         default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
     parser.add_argument(
         "--lr_cycles",
         type=int,
@@ -662,6 +668,7 @@ def main():
         end_lr=1e2,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )
 
     metrics = trainer(
diff --git a/train_ti.py b/train_ti.py
index b182a72..83043ad 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -257,6 +257,12 @@ def parse_args():
         default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_mid_point",
+        type=float,
+        default=0.3,
+        help="OneCycle schedule mid point."
+    )
     parser.add_argument(
         "--lr_cycles",
         type=int,
@@ -790,6 +796,7 @@ def main():
         end_lr=1e3,
         train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
     )
 
     metrics = trainer(
diff --git a/training/optimization.py b/training/optimization.py
index 7d8d55a..59ca950 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -102,6 +102,7 @@ def get_scheduler(
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int = 1,
     min_lr: float = 0.04,
+    mid_point: float = 0.3,
     warmup_func: Literal["cos", "linear"] = "cos",
     annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
@@ -126,6 +127,7 @@ def get_scheduler(
             warmup_exp=warmup_exp,
             annealing_exp=annealing_exp,
             min_lr=min_lr,
+            mid_point=mid_point,
         )
     elif id == "exponential_growth":
         if cycles is None:
-- 
cgit v1.2.3-70-g09d2
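
The models/clip/embeddings.py hunks above replace the old grow-on-demand temp embedding and its unsqueeze/nonzero index translation with a full-size shadow table that mirrors token_embedding, so get_embed() can overlay the trained rows with a single boolean mask. A minimal standalone sketch of that overlay pattern, not part of the patch and using hypothetical toy sizes in place of the real ManagedCLIPTextEmbeddings class:

import torch
import torch.nn as nn

vocab_size, dim = 10, 4  # hypothetical toy sizes
token_embedding = nn.Embedding(vocab_size, dim)

# Full-size shadow table initialised from the base weights, as in the patched __init__.
temp_token_embedding = nn.Embedding(vocab_size, dim)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()

# Pretend tokens 7 and 8 were registered via add_embed() and then trained.
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)
temp_token_embedding.weight.data[temp_token_ids] += 1.0  # stand-in for training updates

# get_embed(): base lookup, then overlay only the rows listed in temp_token_ids.
input_ids = torch.tensor([1, 7, 3, 8])
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
embeds[mask] = temp_token_embedding(input_ids)[mask]

# persist(): fold the trained rows back into the base table and clear the id list.
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]
temp_token_ids = torch.tensor([], dtype=torch.long)

Compared to the removed code, this keeps the temp table the same size as the vocabulary, so add_embed() no longer has to resize it, and indices into the two tables line up directly.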
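
The remaining hunks only plumb the new --lr_mid_point flag from the three training scripts through get_scheduler(); the patch does not show how mid_point is consumed inside the one-cycle schedule. As a rough, hedged illustration of what a mid point typically controls in such a schedule, here is a sketch built on a plain LambdaLR: the function names, the rise/fall shapes, and the min_lr handling below are assumptions, not the repo's actual get_one_cycle_schedule().

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_lambda(total_steps: int, mid_point: float = 0.3, min_lr: float = 0.04):
    peak_step = int(total_steps * mid_point)  # --lr_mid_point: fraction of training at which the LR peaks

    def lr_lambda(step: int) -> float:
        if step < peak_step:
            # rising phase up to the mid point
            return min_lr + (1 - min_lr) * step / max(peak_step, 1)
        # cosine annealing from the peak back down to min_lr
        progress = (step - peak_step) / max(total_steps - peak_step, 1)
        return min_lr + (1 - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

    return lr_lambda

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-3)
scheduler = LambdaLR(optimizer, one_cycle_lambda(total_steps=1000, mid_point=0.3))

In this sketch the default of 0.3 puts the learning-rate peak at roughly 30% of training, with the remaining 70% spent annealing.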