diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 09:25:05 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 09:25:05 +0200 |
commit | 8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68 (patch) | |
tree | 25e6e51fb8e8c691ce5855354a6167d75070b079 | |
parent | Support LoRA training for token embeddings (diff) | |
download | textual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.tar.gz textual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.tar.bz2 textual-inversion-diff-8ea42d6ce516b7d0c43fc7a1e3d5e9db33d72c68.zip |
LoRA: Apply to out layers as well
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 167b17a..70f0dc8 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -32,8 +32,10 @@ from training.util import AverageMeter, save_args | |||
32 | from util.files import load_config, load_embeddings_from_dir | 32 | from util.files import load_config, load_embeddings_from_dir |
33 | 33 | ||
34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
35 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | 35 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
36 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | 36 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key] |
37 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | ||
38 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | ||
37 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 39 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] |
38 | 40 | ||
39 | 41 | ||