From 68164329b97f5cd79a56372dc6cace4b038afce8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 22:08:21 +0100
Subject: Update

---
 models/clip/embeddings.py | 22 +++++++++++-----------
 train_ti.py               |  2 +-
 training/optimization.py  |  6 +++---
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..9c3a56b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(token_ids, int):
             token_ids = [token_ids]

-        if initializer is not None:
-            if isinstance(initializer, int):
-                initializer = [initializer]
+        if initializer is None:
+            initializer = token_ids

-            if isinstance(initializer, list):
-                initializer = (initializer * len(token_ids))[:len(token_ids)]
+        if isinstance(initializer, int):
+            initializer = [initializer]

-            with torch.no_grad():
-                initializer = self.get_embed(initializer)
+        if isinstance(initializer, list):
+            initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+        with torch.no_grad():
+            initializer = self.get_embed(initializer)

         token_ids = torch.tensor(token_ids, dtype=torch.long)

         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-                dtype=self.temp_token_embedding.weight.dtype)
+        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+            dtype=self.temp_token_embedding.weight.dtype)

     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
diff --git a/train_ti.py b/train_ti.py
index 775b918..870bd40 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=1,
+        default=2,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
diff --git a/training/optimization.py b/training/optimization.py
index a79944f..725599b 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -30,7 +30,7 @@ def get_one_cycle_schedule(
                 return min_lr + progress * (1 - min_lr)

             lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** warmup_exp
+            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
             return min_lr + lr * (1 - min_lr)

         if annealing == "linear":
@@ -47,11 +47,11 @@ def get_one_cycle_schedule(

         if annealing == "half_cos":
             lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** annealing_exp
+            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
             return lr

         lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** annealing_exp
+        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
         return lr

     return LambdaLR(optimizer, lr_lambda, last_epoch)
--
cgit v1.2.3-70-g09d2
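
The training/optimization.py change makes the annealing exponent itself decay linearly from annealing_exp down to 1 as progress goes from 0 to 1, rather than applying a fixed exponent to the cosine curve. The following is a minimal standalone sketch of that expression for illustration only; one_cycle_annealing is a hypothetical helper that simply mirrors the "cos" branch of lr_lambda after this commit, not a function from this repository.

import math

def one_cycle_annealing(progress: float, annealing_exp: int = 2) -> float:
    # Cosine annealing factor as it reads after this commit ("cos" branch).
    # progress runs from 0.0 (start of annealing) to 1.0 (end of training);
    # the effective exponent shrinks linearly from annealing_exp down to 1.
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr ** (annealing_exp - (annealing_exp - 1) * progress)

if __name__ == "__main__":
    for p in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f"progress={p:.2f}  factor={one_cycle_annealing(p):.4f}")

With annealing_exp=2 (the new train_ti.py default), the factor falls faster than a plain cosine early in the annealing phase and coincides with it again at the end, where the exponent has decayed to 1.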