Diffstat:
-rw-r--r--  models/clip/embeddings.py  | 22
-rw-r--r--  train_ti.py                |  2
-rw-r--r--  training/optimization.py   |  6

3 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..9c3a56b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
-        if initializer is not None:
-            if isinstance(initializer, int):
-                initializer = [initializer]
+        if initializer is None:
+            initializer = token_ids
 
-            if isinstance(initializer, list):
-                initializer = (initializer * len(token_ids))[:len(token_ids)]
+        if isinstance(initializer, int):
+            initializer = [initializer]
 
-                with torch.no_grad():
-                    initializer = self.get_embed(initializer)
+        if isinstance(initializer, list):
+            initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+            with torch.no_grad():
+                initializer = self.get_embed(initializer)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-                dtype=self.temp_token_embedding.weight.dtype)
+        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+            dtype=self.temp_token_embedding.weight.dtype)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
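Note: the new control flow in add_embed can be read in isolation as the sketch below. It is a standalone approximation, not the repository code; `embedding` is a hypothetical stand-in for the token embedding that self.get_embed reads from.

# Standalone sketch of the new initializer handling (assumed context: `embedding`
# is a hypothetical stand-in for the CLIP token embedding behind self.get_embed).
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 4)  # hypothetical vocabulary of 100 tokens, dim 4

def resolve_initializer(token_ids, initializer=None):
    if isinstance(token_ids, int):
        token_ids = [token_ids]

    # New behaviour in this commit: without an explicit initializer,
    # each new token is initialized from its own current embedding.
    if initializer is None:
        initializer = token_ids

    if isinstance(initializer, int):
        initializer = [initializer]

    if isinstance(initializer, list):
        # Repeat/truncate so there is exactly one initializer id per new token,
        # then look the ids up as embedding vectors.
        initializer = (initializer * len(token_ids))[:len(token_ids)]
        with torch.no_grad():
            initializer = embedding(torch.tensor(initializer, dtype=torch.long))

    return torch.tensor(token_ids, dtype=torch.long), initializer

ids, init = resolve_initializer([7, 8], initializer=3)
print(ids.tolist(), tuple(init.shape))  # [7, 8] (2, 4)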
diff --git a/train_ti.py b/train_ti.py
index 775b918..870bd40 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=1,
+        default=2,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
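The only change in train_ti.py is the default of --lr_annealing_exp (1 -> 2). Below is a minimal check of the parsed default; how the value is forwarded to the scheduler is not part of this diff, so the commented call is only an assumption mirroring training/optimization.py.

# Hypothetical snippet: the flag as declared above, parsed with defaults only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lr_annealing_exp",
    type=int,
    default=2,
    help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
)
args = parser.parse_args([])

# Assumed wiring (not shown in this diff):
# lr_scheduler = get_one_cycle_schedule(optimizer, num_training_steps,
#                                       annealing_exp=args.lr_annealing_exp)
print(args.lr_annealing_exp)  # 2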
diff --git a/training/optimization.py b/training/optimization.py
index a79944f..725599b 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -30,7 +30,7 @@ def get_one_cycle_schedule(
                 return min_lr + progress * (1 - min_lr)
 
             lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** warmup_exp
+            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
             return min_lr + lr * (1 - min_lr)
 
         if annealing == "linear":
@@ -47,11 +47,11 @@ def get_one_cycle_schedule(
 
         if annealing == "half_cos":
             lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** annealing_exp
+            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
             return lr
 
         lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** annealing_exp
+        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
         return lr
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
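The optimization.py change replaces the fixed exponent with one that interpolates linearly from warmup_exp/annealing_exp at progress 0 down to 1 at progress 1, so the curve approaches a plain cosine toward the end and the learning rate is held higher during annealing than with the old fixed exponent. A standalone comparison of the old and new "cos" branch (a sketch, not the repository function itself):

# Standalone check of the old vs. new "cos" annealing factor; assumes the same
# formula as the diff above, with annealing_exp=2 (the new train_ti.py default).
import math

def cos_annealing_old(progress: float, annealing_exp: int = 2) -> float:
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr ** annealing_exp

def cos_annealing_new(progress: float, annealing_exp: int = 2) -> float:
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Exponent slides from annealing_exp at progress=0 to 1 at progress=1.
    return lr ** (annealing_exp - (annealing_exp - 1) * progress)

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress={p:.2f}  old={cos_annealing_old(p):.4f}  new={cos_annealing_new(p):.4f}")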