-rw-r--r--  models/clip/embeddings.py  | 53
-rw-r--r--  train_ti.py                | 24
-rw-r--r--  training/strategy/ti.py    | 30
3 files changed, 44 insertions(+), 63 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9abd1bb..88e0cc0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
+class OverlayLinear(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(f"Rank {rank} must be less than or equal to {min(in_features, out_features)}")
+
+        self.rank = rank
+        self.down = nn.Linear(in_features, rank, bias=False)
+        self.up = nn.Linear(rank, out_features, bias=False)
+        self.reset()
+
+    def reset(self):
+        nn.init.normal_(self.down.weight, std=1 / self.rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states)
+
+        return up_hidden_states.to(orig_dtype)
+
+
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
-        self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings,
-            self.token_embedding.embedding_dim,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype
-        )
-        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
+        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
+    def reset_overlay(self):
+        self.overlay.reset()
+
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = self.get_embed(initializer)
 
         initializer = initializer.to(
-            device=self.temp_token_embedding.weight.device,
-            dtype=self.temp_token_embedding.weight.dtype,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
         )
 
         if initializer_noise != 0:
@@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
+            self.token_embedding.weight.data[self.temp_token_ids]
+        )
+        self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+        embeds[mask] += self.overlay(embeds[mask])
 
         return embeds
 
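Note on the semantics above: OverlayLinear is a LoRA-style low-rank residual. Because up is zero-initialized, a freshly reset overlay contributes nothing, so the managed embeddings start out identical to the base table; persist() then folds the learned delta into token_embedding and resets the overlay. A minimal sketch of that contract (assuming the module above is importable from models.clip.embeddings; the shapes are illustrative):

import torch

from models.clip.embeddings import OverlayLinear

overlay = OverlayLinear(768, 768, rank=128)

# Freshly reset: up.weight is all zeros, so the overlay is a no-op residual.
x = torch.randn(2, 768)
assert torch.equal(overlay(x), torch.zeros_like(x))

# persist() in the diff amounts to folding the residual into the raw rows:
weight = torch.randn(10, 768)    # stand-in for token_embedding.weight.data
temp_ids = torch.tensor([3, 7])  # stand-in for temp_token_ids
weight[temp_ids] += overlay(weight[temp_ids])
overlay.reset()                  # the overlay is a no-op again afterwards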
diff --git a/train_ti.py b/train_ti.py
index 5482326..0ce0056 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,23 +451,6 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
-        type=float,
-        help="Embedding decay factor."
-    )
-    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -732,9 +715,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
@@ -800,7 +780,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.overlay.parameters(),
         lr=args.learning_rate,
     )
 
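With this change the optimizer only sees the overlay's two low-rank matrices rather than a full copy of the embedding table, which is where the bulk of the savings comes from. A rough sketch of the parameter counts (assuming SD 1.x's CLIP text encoder: embedding dim 768, vocab size 49,408, and the new default rank of 128):

from models.clip.embeddings import OverlayLinear

dim, rank, vocab = 768, 128, 49408  # assumed CLIP ViT-L/14 text dimensions

overlay = OverlayLinear(dim, dim, rank)
trainable = sum(p.numel() for p in overlay.parameters())

print(trainable)    # 2 * 768 * 128 = 196,608 trainable parameters
print(vocab * dim)  # 37,945,344 for the removed temp_token_embedding copy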
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.overlay.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.overlay
 
     @contextmanager
     def on_train(epoch: int):
@@ -106,27 +103,9 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
-    @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
 
     def on_log():
         if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
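Since the EMA now tracks the overlay instead of the temp embedding table, the train/eval flow keeps its shape: step() after each optimizer step, and apply_temporary() swaps the averaged weights in for evaluation and restores the raw ones on exit. A hedged sketch of that loop (the training.util import path for EMAModel is an assumption, and the objective is a dummy stand-in):

import torch

from models.clip.embeddings import OverlayLinear
from training.util import EMAModel  # assumed import path for the EMAModel above

overlay = OverlayLinear(768, 768, rank=128)
ema = EMAModel(overlay.parameters(), inv_gamma=1.0, power=1, max_value=0.9999)
optimizer = torch.optim.AdamW(overlay.parameters(), lr=1e-4, weight_decay=1e-2)

target = torch.randn(4, 768)
for _ in range(10):
    loss = (overlay(torch.randn(4, 768)) - target).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(overlay.parameters())  # update the averaged copy after each step

# Evaluate with the averaged weights; the raw weights are restored afterwards.
with ema.apply_temporary(overlay.parameters()):
    _ = overlay(torch.randn(1, 768))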