diff options
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 167b17a..70f0dc8 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -32,8 +32,10 @@ from training.util import AverageMeter, save_args | |||
32 | from util.files import load_config, load_embeddings_from_dir | 32 | from util.files import load_config, load_embeddings_from_dir |
33 | 33 | ||
34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
35 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | 35 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
36 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | 36 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key] |
37 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | ||
38 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | ||
37 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 39 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] |
38 | 40 | ||
39 | 41 | ||