summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py21
1 files changed, 7 insertions, 14 deletions
diff --git a/train_ti.py b/train_ti.py
index 7f5fb49..45e730a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -484,19 +484,13 @@ def parse_args():
484 help="The weight of prior preservation loss." 484 help="The weight of prior preservation loss."
485 ) 485 )
486 parser.add_argument( 486 parser.add_argument(
487 "--lora_r", 487 "--emb_alpha",
488 type=int, 488 type=float,
489 default=8, 489 default=1.0,
490 help="Lora rank, only used if use_lora is True" 490 help="Embedding alpha"
491 )
492 parser.add_argument(
493 "--lora_alpha",
494 type=int,
495 default=32,
496 help="Lora alpha, only used if use_lora is True"
497 ) 491 )
498 parser.add_argument( 492 parser.add_argument(
499 "--lora_dropout", 493 "--emb_dropout",
500 type=float, 494 type=float,
501 default=0, 495 default=0,
502 help="Embedding dropout probability.", 496 help="Embedding dropout probability.",
@@ -669,9 +663,8 @@ def main():
669 663
670 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 664 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
671 args.pretrained_model_name_or_path, 665 args.pretrained_model_name_or_path,
672 args.lora_r, 666 args.emb_alpha,
673 args.lora_alpha, 667 args.emb_dropout
674 args.lora_dropout
675 ) 668 )
676 669
677 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 670 tokenizer.set_use_vector_shuffle(args.vector_shuffle)