From e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 3 Apr 2023 18:52:30 +0200
Subject: TI: Delta learning

Learn new token embeddings as zero-initialized deltas on top of the base
token embedding instead of training a detached copy of the whole
embedding matrix:

* temp_token_embedding becomes an nn.ParameterList of delta vectors.
  get_embed() adds alpha * delta to the base embedding wherever a
  tracked placeholder token appears, and persist() folds alpha * delta
  back into token_embedding before zeroing the deltas again. The scaling
  factor is exposed via --emb_alpha and patch_managed_embeddings().

* The norm-based embedding decay (--use_emb_decay, --emb_decay_target,
  --emb_decay) and its on_before_optimize hook are removed; the
  optimizer's weight decay takes over, with --adam_weight_decay now
  defaulting to 1e-2.

* resize_embedding() zero-initializes new rows when no
  initializer_factor is given, and the Adam beta defaults are simplified
  to "Lion vs. everything else".

(A rough standalone sketch of the delta mechanism follows the diffstat.)
---
 models/clip/embeddings.py | 50 +++++++++++++++++++++++++++++++----------------
 train_ti.py               | 37 +++++++++++------------------------
 training/functional.py    |  4 ++--
 training/strategy/ti.py   | 23 ----------------------
 4 files changed, 46 insertions(+), 68 deletions(-)
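
For reference, here is a minimal, standalone sketch of the delta mechanism
introduced by this patch. It is illustrative only: the class and method
names (DeltaEmbedding, add_token) are made up and do not exist in this
repository, the real logic lives in ManagedCLIPTextEmbeddings, and the
sketch uses a simple per-token loop instead of the vectorized
isin/nonzero matching in get_embed().

    import torch
    import torch.nn as nn

    class DeltaEmbedding(nn.Module):
        """Base embedding plus alpha-scaled trainable deltas for selected tokens."""

        def __init__(self, num_embeddings: int, dim: int, alpha: float = 1.0):
            super().__init__()
            self.alpha = alpha
            self.base = nn.Embedding(num_embeddings, dim)
            self.base.weight.requires_grad_(False)  # only the deltas are trained
            self.deltas = nn.ParameterList()         # zero-initialized delta per tracked token
            self.delta_ids: list[int] = []           # token id handled by each delta, by position

        def add_token(self, token_id: int) -> None:
            self.deltas.append(nn.Parameter(torch.zeros(self.base.embedding_dim)))
            self.delta_ids.append(token_id)

        def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
            embeds = self.base(input_ids)
            for pos, token_id in enumerate(self.delta_ids):
                mask = input_ids == token_id
                if mask.any():
                    embeds[mask] = embeds[mask] + self.alpha * self.deltas[pos]
            return embeds

        @torch.no_grad()
        def persist(self) -> None:
            # Fold the learned deltas into the base weights and reset them,
            # mirroring ManagedCLIPTextEmbeddings.persist().
            for pos, token_id in enumerate(self.delta_ids):
                self.base.weight[token_id] += self.alpha * self.deltas[pos]
                nn.init.zeros_(self.deltas[pos])
            self.delta_ids = []

Training would optimize only the ParameterList (the base matrix stays
untouched), and persist() bakes the result back into the base matrix,
matching the patch's persist() above.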

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1e21965..d8343a0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -12,7 +12,7 @@ from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding:
+def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
 
     if old_num_embeddings == new_num_embeddings:
@@ -26,13 +26,16 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
         device=old_embedding.weight.device,
         dtype=old_embedding.weight.dtype
     )
-    new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+    if initializer_factor is not None:
+        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+    else:
+        nn.init.zeros_(new_embedding.weight.data)
     new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
     return new_embedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -40,17 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.initializer_factor = config.initializer_factor
         self.alpha = alpha
 
-        self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings,
-            self.token_embedding.embedding_dim,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype
-        )
-        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
+        self.temp_token_embedding = nn.ParameterList()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
+        for _ in range(len(self.temp_token_embedding), size):
+            self.temp_token_embedding.append(torch.zeros(
+                self.token_embedding.embedding_dim,
+                device=self.token_embedding.weight.device,
+                dtype=self.token_embedding.weight.dtype,
+            ))
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -85,7 +87,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -96,16 +97,31 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
+            self.token_embedding.weight.data[id] += self.alpha * emb
+            nn.init.zeros_(emb)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
+        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
+
         embeds = self.token_embedding(input_ids)
-        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] = self.temp_token_embedding(input_ids[mask])
+        mask = torch.isin(input_ids, all_temp_token_ids)
+        temp_token_ids = input_ids[mask]
+
+        temp_token_ids = temp_token_ids.unsqueeze(1)
+        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
+        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1]
+
+        if len(temp_token_ids):
+            embeds_override = torch.stack([
+                self.temp_token_embedding[id]
+                for id in temp_token_ids
+            ])
+            embeds[mask] += self.alpha * embeds_override
 
         return embeds
 
@@ -129,7 +145,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
diff --git a/train_ti.py b/train_ti.py
index 8dde1ba..0ad7574 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,21 +451,10 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
+        "--emb_alpha",
+        default=1.0,
         type=float,
-        help="Embedding decay factor."
+        help="Embedding alpha."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -567,16 +556,16 @@ def parse_args():
         raise ValueError("You must specify --output_dir")
 
     if args.adam_beta1 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta1 = 0.9
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta1 = 0.95
+        else:
+            args.adam_beta1 = 0.9
 
     if args.adam_beta2 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta2 = 0.999
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta2 = 0.98
+        else:
+            args.adam_beta2 = 0.999
 
     return args
 
@@ -611,7 +600,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_alpha)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,10 +744,6 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..1d8e2ee 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder)
+    embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index c7520ed..16baa34 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -105,29 +101,11 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
-- 
cgit v1.2.3-70-g09d2