From 30b557c8e1f03b4748ac3efca599ff51d66561cb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Apr 2023 07:30:43 +0200 Subject: TI: Bring back old embedding decay --- models/clip/embeddings.py | 15 +++++++-------- models/sparse.py | 14 ++++++++------ train_ti.py | 24 +++++++++++++++++++----- training/functional.py | 4 ++-- training/strategy/ti.py | 22 +++++++++++++++++++++- 5 files changed, 57 insertions(+), 22 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index a356434..63a141f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): super().__init__(config) self.token_embedding = embeddings.token_embedding @@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype, ) - self.alpha = alpha def resize(self, size: int): self.token_override_embedding.resize(size) @@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.token_embedding.weight.data[token_ids] = initializer - self.token_override_embedding.set(token_ids) + self.token_override_embedding.set(token_ids, initializer) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): embs, mask = self.token_override_embedding(input_ids) if embs is not None: input_ids = input_ids[mask] - self.token_embedding.weight.data[input_ids] += self.alpha * embs - self.token_override_embedding.unset(input_ids) + self.token_embedding.weight.data[input_ids] = embs + self.token_override_embedding.unset(input_ids) def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): @@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): embs = self.token_embedding(input_ids) embs_override, mask = self.token_override_embedding(input_ids) if embs_override is not None: - embs[mask] += self.alpha * embs_override + embs[mask] = embs_override return embs @@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeddings -def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: - text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) +def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: + text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/models/sparse.py b/models/sparse.py index 0b15454..8910316 100644 --- a/models/sparse.py +++ b/models/sparse.py @@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module): self.params = nn.ParameterList() self.mapping = torch.zeros(0, device=device, dtype=torch.long) - def forward(self, input_ids: Optional[torch.LongTensor] = None): - if input_ids is None: - input_ids = torch.arange(self.mapping.shape[0]) - + def forward(self, input_ids: torch.LongTensor): ids = self.mapping[input_ids.to(self.mapping.device)] mask = ~(ids == -1) @@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module): else: return [self.set(id) for id in input_ids] + if tensor is None: + tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) + + if tensor.shape[-1] != self.embedding_dim: + raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") + id = self.mapping[input_ids] if id == -1: @@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module): self.mapping[input_ids] = id self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) - self.params[id] = tensor if tensor is not None else torch.zeros( - self.embedding_dim, device=self.mapping.device, dtype=self.dtype) + self.params[id] = tensor def unset(self, input_ids: torch.LongTensor): self.mapping[input_ids] = -1 diff --git a/train_ti.py b/train_ti.py index a9a2333..4366c9e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -353,7 +353,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=1e-2, + default=0, help="Weight decay to use." ) parser.add_argument( @@ -451,10 +451,21 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--emb_alpha", - default=1.0, + "--use_emb_decay", + action="store_true", + help="Whether to use embedding decay." + ) + parser.add_argument( + "--emb_decay_target", + default=0.4, + type=float, + help="Embedding decay target." + ) + parser.add_argument( + "--emb_decay", + default=1e+2, type=float, - help="Embedding alpha." + help="Embedding decay factor." ) parser.add_argument( "--noise_timesteps", @@ -600,7 +611,7 @@ def main(): save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path, args.emb_alpha) + args.pretrained_model_name_or_path) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -744,6 +755,9 @@ def main(): tokenizer=tokenizer, sample_scheduler=sample_scheduler, checkpoint_output_dir=checkpoint_output_dir, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay=args.emb_decay, use_ema=args.use_ema, ema_inv_gamma=args.ema_inv_gamma, ema_power=args.ema_power, diff --git a/training/functional.py b/training/functional.py index 1d8e2ee..96ecbc1 100644 --- a/training/functional.py +++ b/training/functional.py @@ -73,7 +73,7 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): +def get_models(pretrained_model_name_or_path: str): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder, emb_alpha) + embeddings = patch_managed_embeddings(text_encoder) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 95128da..9df160a 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks( seed: int, placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], + use_emb_decay: bool = False, + emb_decay_target: float = 0.4, + emb_decay: float = 1e-2, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, @@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_after_optimize(zero_ids, lr: float): + def on_before_optimize(lr: float, epoch: int): + if use_emb_decay: + return torch.stack([ + p + for p in text_encoder.text_model.embeddings.token_override_embedding.params + if p.grad is not None + ]) + + @torch.no_grad() + def on_after_optimize(w, lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) + if use_emb_decay: + lambda_ = emb_decay * lr + + if lambda_ != 0: + norm = w[:, :].norm(dim=-1, keepdim=True) + w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} @@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks( on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, + on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2