From e454b7e7df13cf6ce7b96b7dcc107533edf83f6f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 12 Jan 2023 08:51:17 +0100 Subject: Fixed TI decay --- models/clip/embeddings.py | 10 +++++++++- train_ti.py | 11 +++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9d8f770..46b414b 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -3,6 +3,7 @@ from pathlib import Path import torch import torch.nn as nn +import torch.nn.functional as F from safetensors import safe_open from safetensors.torch import save_file @@ -45,7 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype ) - self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) + self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): @@ -98,6 +99,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeds + def normalize(self, lambda_: float = 1.0): + w = self.temp_token_embedding.weight + pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) + w[self.temp_token_ids] = F.normalize( + w[self.temp_token_ids, :], dim=-1 + ) * (pre_norm + lambda_ * (0.4 - pre_norm)) + def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/train_ti.py b/train_ti.py index 2c5037f..890c465 100644 --- a/train_ti.py +++ b/train_ti.py @@ -7,7 +7,6 @@ from pathlib import Path from contextlib import contextmanager, nullcontext import torch -import torch.nn.functional as F import torch.utils.checkpoint from accelerate import Accelerator @@ -166,7 +165,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0.1, + default=0, help="Tag dropout probability.", ) parser.add_argument( @@ -177,7 +176,7 @@ def parse_args(): parser.add_argument( "--vector_dropout", type=int, - default=0.1, + default=0, help="Vector dropout probability.", ) parser.add_argument( @@ -869,11 +868,7 @@ def main(): @torch.no_grad() def on_clip(lr): - embeddings = text_encoder.text_model.embeddings.temp_token_embedding - - pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) - lambda_ = min(1.0, 100 * lr) - embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) + text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) loop = partial( loss_step, -- cgit v1.2.3-70-g09d2