diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
commit | 7b149930bb53b93db74106ad20a30abf4b114f9b (patch) | |
tree | 67c2ccbce2a9838ad8a020ee527b19113e67e30a /models | |
parent | Added TI decay start offset (diff) | |
download | textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.gz textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.bz2 textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.zip |
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 6 | ||||
-rw-r--r-- | models/clip/prompt.py | 38 | ||||
-rw-r--r-- | models/clip/util.py | 34 |
3 files changed, 39 insertions, 39 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9a23a2a..761efbc 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
42 | 42 | ||
43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
44 | |||
43 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.Embedding( |
44 | self.token_embedding.num_embeddings, | 46 | self.token_embedding.num_embeddings, |
45 | self.token_embedding.embedding_dim, | 47 | self.token_embedding.embedding_dim, |
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
99 | 101 | ||
100 | return embeds | 102 | return embeds |
101 | 103 | ||
102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | 104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): |
105 | if target is None: | ||
106 | target = self.decay_target | ||
103 | w = self.temp_token_embedding.weight | 107 | w = self.temp_token_embedding.weight |
104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
105 | w[self.temp_token_ids] = F.normalize( | 109 | w[self.temp_token_ids] = F.normalize( |
diff --git a/models/clip/prompt.py b/models/clip/prompt.py deleted file mode 100644 index a7380be..0000000 --- a/models/clip/prompt.py +++ /dev/null | |||
@@ -1,38 +0,0 @@ | |||
1 | from typing import Union, Optional | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | |||
7 | |||
8 | class PromptProcessor(): | ||
9 | def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): | ||
10 | self.tokenizer = tokenizer | ||
11 | self.text_encoder = text_encoder | ||
12 | |||
13 | def get_input_ids(self, prompt: Union[str, list[str]]): | ||
14 | return self.tokenizer( | ||
15 | prompt, | ||
16 | padding="do_not_pad", | ||
17 | ).input_ids | ||
18 | |||
19 | def unify_input_ids(self, input_ids: list[list[int]]): | ||
20 | return self.tokenizer.pad( | ||
21 | {"input_ids": input_ids}, | ||
22 | padding=True, | ||
23 | pad_to_multiple_of=self.tokenizer.model_max_length, | ||
24 | return_tensors="pt" | ||
25 | ) | ||
26 | |||
27 | def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): | ||
28 | prompts = input_ids.shape[0] | ||
29 | |||
30 | input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
31 | if position_ids is not None: | ||
32 | position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
33 | if attention_mask is not None: | ||
34 | attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
35 | |||
36 | text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
37 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
38 | return text_embeddings | ||
diff --git a/models/clip/util.py b/models/clip/util.py new file mode 100644 index 0000000..8de8c19 --- /dev/null +++ b/models/clip/util.py | |||
@@ -0,0 +1,34 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | |||
7 | |||
8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): | ||
9 | return tokenizer.pad( | ||
10 | {"input_ids": input_ids}, | ||
11 | padding=True, | ||
12 | pad_to_multiple_of=tokenizer.model_max_length, | ||
13 | return_tensors="pt" | ||
14 | ) | ||
15 | |||
16 | |||
17 | def get_extended_embeddings( | ||
18 | text_encoder: CLIPTextModel, | ||
19 | input_ids: torch.LongTensor, | ||
20 | position_ids: Optional[torch.LongTensor] = None, | ||
21 | attention_mask=None | ||
22 | ): | ||
23 | model_max_length = text_encoder.config.max_position_embeddings | ||
24 | prompts = input_ids.shape[0] | ||
25 | |||
26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
27 | if position_ids is not None: | ||
28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
29 | if attention_mask is not None: | ||
30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | ||
31 | |||
32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
34 | return text_embeddings | ||