author     Volpeon <git@volpeon.ink>   2023-01-13 13:49:35 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-13 13:49:35 +0100
commit     7b149930bb53b93db74106ad20a30abf4b114f9b
tree       67c2ccbce2a9838ad8a020ee527b19113e67e30a /models
parent     Added TI decay start offset
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py    6
-rw-r--r--  models/clip/prompt.py       38
-rw-r--r--  models/clip/util.py         34
3 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..761efbc 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
+        self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item()
+
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
+    def normalize(self, target: Optional[float] = None, lambda_: float = 1.0):
+        if target is None:
+            target = self.decay_target
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
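
For reference, a minimal standalone sketch of the decay/normalize step changed above: the target norm now defaults to the median L2 norm of the pretrained token embeddings (self.decay_target), and the selected rows are renormalized toward it. The function name and the lambda_ blend of pre_norm and target are assumptions, since the body of the F.normalize(...) call is truncated in this hunk.

from typing import Optional

import torch
import torch.nn.functional as F


def normalize_rows_toward_target(weight: torch.Tensor,
                                 row_ids: torch.Tensor,
                                 target: Optional[float] = None,
                                 lambda_: float = 1.0) -> None:
    # Default target: median L2 norm over all embedding rows,
    # mirroring self.decay_target in the diff above.
    if target is None:
        target = weight.norm(dim=-1, keepdim=True).median().item()
    pre_norm = weight[row_ids, :].norm(dim=-1, keepdim=True)
    # Assumed blend: move each row's norm a fraction lambda_ of the way
    # from its current norm toward the target norm.
    new_norm = pre_norm + lambda_ * (target - pre_norm)
    weight[row_ids] = F.normalize(weight[row_ids, :], dim=-1) * new_norm
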
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from typing import Union, Optional
-
-import torch
-
-from transformers import CLIPTokenizer, CLIPTextModel
-
-
-class PromptProcessor():
-    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-
-    def get_input_ids(self, prompt: Union[str, list[str]]):
-        return self.tokenizer(
-            prompt,
-            padding="do_not_pad",
-        ).input_ids
-
-    def unify_input_ids(self, input_ids: list[list[int]]):
-        return self.tokenizer.pad(
-            {"input_ids": input_ids},
-            padding=True,
-            pad_to_multiple_of=self.tokenizer.model_max_length,
-            return_tensors="pt"
-        )
-
-    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
-        prompts = input_ids.shape[0]
-
-        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if position_ids is not None:
-            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if attention_mask is not None:
-            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-
-        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
-        return text_embeddings
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
+    return tokenizer.pad(
+        {"input_ids": input_ids},
+        padding=True,
+        pad_to_multiple_of=tokenizer.model_max_length,
+        return_tensors="pt"
+    )
+
+
+def get_extended_embeddings(
+    text_encoder: CLIPTextModel,
+    input_ids: torch.LongTensor,
+    position_ids: Optional[torch.LongTensor] = None,
+    attention_mask=None
+):
+    model_max_length = text_encoder.config.max_position_embeddings
+    prompts = input_ids.shape[0]
+
+    input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if position_ids is not None:
+        position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device)
+
+    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
+    return text_embeddings
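
To illustrate how the new module-level helpers replace the removed PromptProcessor at a call site, a hedged usage sketch follows; only unify_input_ids and get_extended_embeddings come from models/clip/util.py, while the pretrained checkpoint name and the example prompts are placeholders standing in for whatever the training loop supplies.

from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

# Placeholder checkpoint; the training scripts load their own models.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a painting of a cat at night"]

# Tokenize without padding, then pad to a multiple of model_max_length
# so oversized prompts can be split into encoder-sized chunks.
input_ids = tokenizer(prompts, padding="do_not_pad").input_ids
batch = unify_input_ids(tokenizer, input_ids)

text_embeddings = get_extended_embeddings(
    text_encoder,
    batch.input_ids,
    attention_mask=batch.attention_mask,
)
print(text_embeddings.shape)  # (num_prompts, padded_length, hidden_size)
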