author     Volpeon <git@volpeon.ink>  2023-04-16 14:45:37 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-16 14:45:37 +0200
commit     3924055ed24da9b6995303cd36282eb558ba0bf0 (patch)
tree       4fed8dabcde2236e1a1e8f5738b2a0bdcfd4513b /train_lora.py
parent     Fix (diff)
Fix
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  11
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/train_lora.py b/train_lora.py
index d5dde02..5c78664 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -507,6 +507,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
+    )
+    parser.add_argument(
         "--emb_dropout",
         type=float,
         default=0,
@@ -660,7 +666,10 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.emb_alpha,
+        args.emb_dropout
+    )
 
     unet_config = LoraConfig(
         r=args.lora_r,
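The change adds an `--emb_alpha` option and threads it into `get_models` ahead of the existing `emb_dropout` value. As a rough illustration of what an embedding alpha of this kind typically controls, below is a minimal sketch in which trainable embedding deltas are scaled by `alpha` before dropout is applied; the `ScaledEmbedding` wrapper and its internals are hypothetical assumptions for illustration, not this repository's actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledEmbedding(nn.Module):
    """Hypothetical sketch: mixes trainable embedding deltas into a frozen
    base embedding, scaled by `alpha`, with optional embedding dropout."""

    def __init__(self, base: nn.Embedding, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.base = base                      # frozen pretrained embeddings
        self.alpha = alpha                    # weight given to the learned deltas
        self.dropout = nn.Dropout(p=dropout)
        # Trainable deltas, zero-initialized so training starts from the base.
        self.delta = nn.Parameter(torch.zeros_like(base.weight))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # With alpha=1.0 and dropout=0 (the defaults added in this commit),
        # this reduces to a plain embedding lookup plus the learned deltas.
        weight = self.base.weight + self.alpha * self.delta
        return self.dropout(F.embedding(input_ids, weight))

Since the new flag defaults to 1.0, existing invocations of train_lora.py behave as before; the influence of the learned embeddings only changes when the flag is set explicitly, e.g. `--emb_alpha 0.5`.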