summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-16 07:12:14 +0200
committerVolpeon <git@volpeon.ink>2023-05-16 07:12:14 +0200
commitb31fcb741432076f7e2f3ec9423ad935a08c6671 (patch)
tree2ab052d3bd617a56c4ea388c200da52cff39ba37 /models/clip/embeddings.py
parentFix for latest PEFT (diff)
downloadtextual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz
textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2
textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip
Support LoRA training for token embeddings
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b23bd3..7c7f2ac 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -86,6 +86,9 @@ def patch_managed_embeddings(
86 alpha: int = 8, 86 alpha: int = 8,
87 dropout: float = 0.0 87 dropout: float = 0.0
88) -> ManagedCLIPTextEmbeddings: 88) -> ManagedCLIPTextEmbeddings:
89 if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
90 return text_encoder.text_model.embeddings
91
89 text_embeddings = ManagedCLIPTextEmbeddings( 92 text_embeddings = ManagedCLIPTextEmbeddings(
90 text_encoder.config, 93 text_encoder.config,
91 text_encoder.text_model.embeddings, 94 text_encoder.text_model.embeddings,