author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2 /models/clip
parent     Fix LoRA training with DAdan
Update
Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/embeddings.py | 29
-rw-r--r--  models/clip/tokenizer.py  | 23
-rw-r--r--  models/clip/util.py       | 17
3 files changed, 41 insertions, 28 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7c7f2ac..8c3c6d4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        embeddings: CLIPTextEmbeddings,
+        alpha: int = 8,
+        dropout: float = 0.0,
+    ):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
-        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(
+            size, self.initializer_factor
+        )
 
     def add_embed(
         self,
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
-            initializer = (initializer * len(token_ids))[:len(token_ids)]
+            initializer = (initializer * len(token_ids))[: len(token_ids)]
 
         with torch.no_grad():
             initializer = self.get_embed(initializer)
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
+            input_ids = torch.tensor(
+                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
+            )
 
         return self.token_embedding(input_ids)
 
 
 def patch_managed_embeddings(
-    text_encoder: CLIPTextModel,
-    alpha: int = 8,
-    dropout: float = 0.0
+    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
         return text_encoder.text_model.embeddings
 
     text_embeddings = ManagedCLIPTextEmbeddings(
-        text_encoder.config,
-        text_encoder.text_model.embeddings,
-        alpha,
-        dropout
+        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
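
Not part of the commit: a minimal usage sketch of the embeddings API touched above, assuming the repository's models.clip.embeddings (and models.sparse) modules are importable; the checkpoint name and the extra token count are illustrative.

# Hedged sketch, not from the repository.
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Swap the stock embeddings for the managed wrapper defined in this file.
embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

# Grow the token embedding to make room for two new placeholder tokens, then
# look up their vectors (get_embed accepts a plain Python list of ids).
embeddings.resize(text_encoder.config.vocab_size + 2)
new_token_embeds = embeddings.get_embed(
    [text_encoder.config.vocab_size, text_encoder.config.vocab_size + 1]
)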
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 789b525..a866641 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(
-        self,
-        new_tokens: Union[str, list[str]],
-        num_vectors: Union[int, list[int]] = 1
+        self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
     ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
 
             if len(num_vectors) != len(new_tokens):
-                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+                raise ValueError(
+                    "Expected new_tokens and num_vectors to have the same len"
+                )
 
-            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+            return [
+                self.add_multi_tokens(new_token, vecs)
+                for new_token, vecs in zip(new_tokens, num_vectors)
+            ]
 
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
             return [id]
 
     def expand_ids(self, ids: list[int]):
-        return [
-            new_id
-            for id in ids
-            for new_id in self.expand_id(id)
-        ]
+        return [new_id for id in ids for new_id in self.expand_id(id)]
 
-    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
+    def expand_batched_ids(
+        self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
+    ):
         if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
             return [self.expand_ids(batch) for batch in input_ids]
         else:
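
Not part of the commit: a brief sketch of how the reformatted MultiCLIPTokenizer methods fit together. The checkpoint name and placeholder strings are illustrative, and it assumes add_multi_tokens registers the placeholders with the underlying CLIPTokenizer.

# Hedged sketch, not from the repository.
from models.clip.tokenizer import MultiCLIPTokenizer

tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Register two placeholder tokens; "<style>" spans three embedding vectors.
# For a list of new tokens the method returns one id list per token.
ids = tokenizer.add_multi_tokens(["<concept>", "<style>"], num_vectors=[1, 3])

# Tokenize a prompt and expand each batch entry so multi-vector placeholders
# occupy all of their ids (expand_batched_ids applies expand_ids per entry).
batch = tokenizer(["a photo of <concept> in the style of <style>"])["input_ids"]
expanded = tokenizer.expand_batched_ids(batch)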
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
+
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
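
Not part of the commit: a sketch showing the two util helpers together, padding a batch to a multiple of the model's max length and encoding it chunk-wise. The checkpoint name and prompts are illustrative, and the exact chunking behaviour of get_extended_embeddings is assumed from the visible context lines.

# Hedged sketch, not from the repository.
from transformers import CLIPTextModel, CLIPTokenizer

from models.clip.util import get_extended_embeddings, unify_input_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a watercolor painting of a castle at sunset"]
input_ids = tokenizer(prompts)["input_ids"]

# With max_length=None, pads to a multiple of tokenizer.model_max_length so each
# prompt can be split into full-length chunks for the text encoder.
batch = unify_input_ids(tokenizer, input_ids)
text_embeddings = get_extended_embeddings(
    text_encoder, batch.input_ids, attention_mask=batch.attention_mask
)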