From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 13:49:35 +0100
Subject: Removed PromptProcessor, modularized training loop

---
 models/clip/embeddings.py |  6 +++++-
 models/clip/prompt.py     | 38 --------------------------------------
 models/clip/util.py       | 34 ++++++++++++++++++++++++++++++++++
 3 files changed, 39 insertions(+), 39 deletions(-)
 delete mode 100644 models/clip/prompt.py
 create mode 100644 models/clip/util.py

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..761efbc 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
+        self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item()
+
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
+    def normalize(self, target: Optional[float] = None, lambda_: float = 1.0):
+        if target is None:
+            target = self.decay_target
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from typing import Union, Optional
-
-import torch
-
-from transformers import CLIPTokenizer, CLIPTextModel
-
-
-class PromptProcessor():
-    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-
-    def get_input_ids(self, prompt: Union[str, list[str]]):
-        return self.tokenizer(
-            prompt,
-            padding="do_not_pad",
-        ).input_ids
-
-    def unify_input_ids(self, input_ids: list[list[int]]):
-        return self.tokenizer.pad(
-            {"input_ids": input_ids},
-            padding=True,
-            pad_to_multiple_of=self.tokenizer.model_max_length,
-            return_tensors="pt"
-        )
-
-    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
-        prompts = input_ids.shape[0]
-
-        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if position_ids is not None:
-            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if attention_mask is not None:
-            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-
-        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
-        return text_embeddings
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
+    return tokenizer.pad(
+        {"input_ids": input_ids},
+        padding=True,
+        pad_to_multiple_of=tokenizer.model_max_length,
+        return_tensors="pt"
+    )
+
+
+def get_extended_embeddings(
+    text_encoder: CLIPTextModel,
+    input_ids: torch.LongTensor,
+    position_ids: Optional[torch.LongTensor] = None,
+    attention_mask=None
+):
+    model_max_length = text_encoder.config.max_position_embeddings
+    prompts = input_ids.shape[0]
+
+    input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if position_ids is not None:
+        position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device)
+
+    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
+    return text_embeddings
-- 
cgit v1.2.3-54-g00ecf
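
For context: normalize() in models/clip/embeddings.py now defaults its target
to decay_target, the median L2 norm of the pretrained token embeddings, instead
of the fixed 0.4. The former PromptProcessor methods live on as free functions
in models/clip/util.py that take the tokenizer or text encoder explicitly. A
minimal usage sketch follows, assuming a standard CLIP checkpoint and
placeholder prompts (neither is part of this patch):

    from transformers import CLIPTokenizer, CLIPTextModel

    from models.clip.util import unify_input_ids, get_extended_embeddings

    # Assumed checkpoint; any matching CLIP tokenizer/text-encoder pair works.
    pretrained = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Tokenize without padding, as the removed PromptProcessor.get_input_ids did.
    prompts = ["a photo of a cat", "a watercolor painting of a fox"]
    input_ids = tokenizer(prompts, padding="do_not_pad").input_ids

    # Pad every prompt to a multiple of model_max_length (77 for CLIP) so that
    # overlong prompts split into 77-token chunks and encode as one batch.
    batch = unify_input_ids(tokenizer, input_ids)

    # Chunks are folded into the batch dimension for encoding, then the
    # outputs are reshaped back into one embedding sequence per prompt.
    embeddings = get_extended_embeddings(
        text_encoder,
        batch.input_ids,
        attention_mask=batch.attention_mask,
    )
    # embeddings: (num_prompts, chunks * 77, hidden_size)

Passing the tokenizer and text encoder as arguments, rather than capturing them
in a processor object, is what lets the modularized training loop swap either
component without rebuilding any wrapper state.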