diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 21 |
1 files changed, 7 insertions, 14 deletions
diff --git a/train_ti.py b/train_ti.py index 7f5fb49..45e730a 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -484,19 +484,13 @@ def parse_args(): | |||
484 | help="The weight of prior preservation loss." | 484 | help="The weight of prior preservation loss." |
485 | ) | 485 | ) |
486 | parser.add_argument( | 486 | parser.add_argument( |
487 | "--lora_r", | 487 | "--emb_alpha", |
488 | type=int, | 488 | type=float, |
489 | default=8, | 489 | default=1.0, |
490 | help="Lora rank, only used if use_lora is True" | 490 | help="Embedding alpha" |
491 | ) | ||
492 | parser.add_argument( | ||
493 | "--lora_alpha", | ||
494 | type=int, | ||
495 | default=32, | ||
496 | help="Lora alpha, only used if use_lora is True" | ||
497 | ) | 491 | ) |
498 | parser.add_argument( | 492 | parser.add_argument( |
499 | "--lora_dropout", | 493 | "--emb_dropout", |
500 | type=float, | 494 | type=float, |
501 | default=0, | 495 | default=0, |
502 | help="Embedding dropout probability.", | 496 | help="Embedding dropout probability.", |
@@ -669,9 +663,8 @@ def main(): | |||
669 | 663 | ||
670 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 664 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
671 | args.pretrained_model_name_or_path, | 665 | args.pretrained_model_name_or_path, |
672 | args.lora_r, | 666 | args.emb_alpha, |
673 | args.lora_alpha, | 667 | args.emb_dropout |
674 | args.lora_dropout | ||
675 | ) | 668 | ) |
676 | 669 | ||
677 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 670 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |