diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 11 |
1 file changed, 10 insertions, 1 deletion
diff --git a/train_lora.py b/train_lora.py index d5dde02..5c78664 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -507,6 +507,12 @@ def parse_args(): | |||
507 | help="The weight of prior preservation loss." | 507 | help="The weight of prior preservation loss." |
508 | ) | 508 | ) |
509 | parser.add_argument( | 509 | parser.add_argument( |
510 | "--emb_alpha", | ||
511 | type=float, | ||
512 | default=1.0, | ||
513 | help="Embedding alpha" | ||
514 | ) | ||
515 | parser.add_argument( | ||
510 | "--emb_dropout", | 516 | "--emb_dropout", |
511 | type=float, | 517 | type=float, |
512 | default=0, | 518 | default=0, |
@@ -660,7 +666,10 @@ def main(): | |||
660 | save_args(output_dir, args) | 666 | save_args(output_dir, args) |
661 | 667 | ||
662 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 668 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
663 | args.pretrained_model_name_or_path, args.emb_dropout) | 669 | args.pretrained_model_name_or_path, |
670 | args.emb_alpha, | ||
671 | args.emb_dropout | ||
672 | ) | ||
664 | 673 | ||
665 | unet_config = LoraConfig( | 674 | unet_config = LoraConfig( |
666 | r=args.lora_r, | 675 | r=args.lora_r, |