diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
commit | 3924055ed24da9b6995303cd36282eb558ba0bf0 (patch) | |
tree | 4fed8dabcde2236e1a1e8f5738b2a0bdcfd4513b /train_lora.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.tar.gz textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.tar.bz2 textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.zip |
Fix
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index d5dde02..5c78664 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -507,6 +507,12 @@ def parse_args(): | |||
507 | help="The weight of prior preservation loss." | 507 | help="The weight of prior preservation loss." |
508 | ) | 508 | ) |
509 | parser.add_argument( | 509 | parser.add_argument( |
510 | "--emb_alpha", | ||
511 | type=float, | ||
512 | default=1.0, | ||
513 | help="Embedding alpha" | ||
514 | ) | ||
515 | parser.add_argument( | ||
510 | "--emb_dropout", | 516 | "--emb_dropout", |
511 | type=float, | 517 | type=float, |
512 | default=0, | 518 | default=0, |
@@ -660,7 +666,10 @@ def main(): | |||
660 | save_args(output_dir, args) | 666 | save_args(output_dir, args) |
661 | 667 | ||
662 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 668 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
663 | args.pretrained_model_name_or_path, args.emb_dropout) | 669 | args.pretrained_model_name_or_path, |
670 | args.emb_alpha, | ||
671 | args.emb_dropout | ||
672 | ) | ||
664 | 673 | ||
665 | unet_config = LoraConfig( | 674 | unet_config = LoraConfig( |
666 | r=args.lora_r, | 675 | r=args.lora_r, |