From adc52fb8821a496bc8d78235bf10466b39df03e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 1 Jan 2023 19:19:52 +0100 Subject: Updates --- models/clip/embeddings.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index f90e7c2..8602142 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) text_encoder.text_model.embeddings = text_embeddings return text_embeddings + + +def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings: + text_encoder.text_model.embeddings.make_permanent() + + text_embeddings = CLIPTextEmbeddings(text_encoder.config) + text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding + text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding + text_encoder.text_model.embeddings = text_embeddings + + return text_embeddings -- cgit v1.2.3-70-g09d2