summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-01 20:48:04 +0100
committerVolpeon <git@volpeon.ink>2023-01-01 20:48:04 +0100
commiteb0838bd2bf96d34dd779f847552291379fe543f (patch)
tree501c41a8330a06ee0b0939a47ae74c281129ab47 /models
parentFix MultiCLIPTokenizer (forgot to override encode) (diff)
downloadtextual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.gz
textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.bz2
textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.zip
Cleanup
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py11
-rw-r--r--models/clip/tokenizer.py1
2 files changed, 1 insertions, 11 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8602142..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,14 +120,3 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) 120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
121 text_encoder.text_model.embeddings = text_embeddings 121 text_encoder.text_model.embeddings = text_embeddings
122 return text_embeddings 122 return text_embeddings
123
124
125def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
126 text_encoder.text_model.embeddings.make_permanent()
127
128 text_embeddings = CLIPTextEmbeddings(text_encoder.config)
129 text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
130 text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
131 text_encoder.text_model.embeddings = text_embeddings
132
133 return text_embeddings
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 5e33f3e..bd0bd21 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -57,6 +57,7 @@ class MultiCLIPTokenizerItem(NamedTuple):
57class MultiCLIPTokenizer(CLIPTokenizer): 57class MultiCLIPTokenizer(CLIPTokenizer):
58 def __init__(self, *args, **kwargs): 58 def __init__(self, *args, **kwargs):
59 super().__init__(*args, **kwargs) 59 super().__init__(*args, **kwargs)
60
60 self.token_map: dict[int, list[int]] = {} 61 self.token_map: dict[int, list[int]] = {}
61 self.vector_shuffle = shuffle_none 62 self.vector_shuffle = shuffle_none
62 63