From c96073646bbb638d7d78fdd7d9fdeed08d1454b5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 16:30:36 +0200
Subject: Experimental: TI via LoRA

---
 models/clip/embeddings.py | 53 +++++++++++++++++++++++++++++++++--------------
 train_ti.py               | 24 ++----------------------
 training/strategy/ti.py   | 30 ++++--------------------------
 3 files changed, 44 insertions(+), 63 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9abd1bb..88e0cc0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
+class OverlayLinear(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
+
+        self.rank = rank
+        self.down = nn.Linear(in_features, rank, bias=False)
+        self.up = nn.Linear(rank, out_features, bias=False)
+        self.reset()
+
+    def reset(self):
+        nn.init.normal_(self.down.weight, std=1 / self.rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states)
+
+        return up_hidden_states.to(orig_dtype)
+
+
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
-        self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings,
-            self.token_embedding.embedding_dim,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype
-        )
-        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
+        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
+    def reset_overlay(self):
+        self.overlay.reset()
+
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         initializer = self.get_embed(initializer)
         initializer = initializer.to(
-            device=self.temp_token_embedding.weight.device,
-            dtype=self.temp_token_embedding.weight.dtype,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
         )
 
         if initializer_noise != 0:
@@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
+            self.token_embedding.weight.data[self.temp_token_ids]
+        )
+        self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+        embeds[mask] += self.overlay(embeds[mask])
 
         return embeds
 
diff --git a/train_ti.py b/train_ti.py
index 5482326..0ce0056 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -450,23 +450,6 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
-    parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
-        type=float,
-        help="Embedding decay factor."
-    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -732,9 +715,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
@@ -800,7 +780,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.overlay.parameters(),
         lr=args.learning_rate,
     )
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.overlay.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.overlay
 
     @contextmanager
     def on_train(epoch: int):
@@ -105,28 +102,10 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
 
     def on_log():
         if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
-- 
cgit v1.2.3-70-g09d2
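
Note on the approach (illustration only, not part of the patch): instead of training a per-token `temp_token_embedding`, the trainable state is now a single LoRA-style low-rank overlay (`OverlayLinear`, a down-projection followed by a zero-initialized up-projection) whose output is added to the frozen token embeddings of the managed token ids, and `persist()` bakes that delta into `token_embedding.weight`. Because the up-projection starts at zero, training begins from the unchanged embeddings. The sketch below re-creates the idea standalone; the class name `LowRankOverlay`, the sizes, and the token ids are made-up assumptions for the example, not the repository's API.

```python
import torch
import torch.nn as nn


class LowRankOverlay(nn.Module):
    """Standalone re-creation of the OverlayLinear idea: a rank-r residual
    that starts as a no-op because the up-projection is zero-initialized."""

    def __init__(self, dim: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))


# Hypothetical sizes and token ids, chosen only for this example.
vocab_size, dim, rank = 49408, 768, 128
token_embedding = nn.Embedding(vocab_size, dim)
token_embedding.requires_grad_(False)        # frozen base embedding table
overlay = LowRankOverlay(dim, rank)          # the only trainable part
managed_ids = torch.tensor([49406, 49407])   # stand-ins for placeholder token ids

# Lookup: the low-rank delta is applied only to the managed tokens.
input_ids = torch.tensor([[49406, 320, 49407]])
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, managed_ids)
embeds[mask] = embeds[mask] + overlay(embeds[mask])

# "Persisting" bakes the learned delta into the embedding table.
with torch.no_grad():
    rows = token_embedding.weight[managed_ids]
    token_embedding.weight[managed_ids] = rows + overlay(rows)
```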