summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-16 09:25:05 +0200
committerVolpeon <git@volpeon.ink>2023-05-16 09:25:05 +0200
commit8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68 (patch)
tree25e6e51fb8e8c691ce5855354a6167d75070b079
parentSupport LoRA training for token embeddings (diff)
downloadtextual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.tar.gz
textual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.tar.bz2
textual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.zip
LoRA: Apply to out layers as well
-rw-r--r--train_lora.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index 167b17a..70f0dc8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -32,8 +32,10 @@ from training.util import AverageMeter, save_args
32from util.files import load_config, load_embeddings_from_dir 32from util.files import load_config, load_embeddings_from_dir
33 33
34# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py 34# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
35UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] 35UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"]
36TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] 36UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key]
37TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"]
38TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"]
37TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] 39TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
38 40
39 41