| author | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
| commit | 3924055ed24da9b6995303cd36282eb558ba0bf0 | |
| tree | 4fed8dabcde2236e1a1e8f5738b2a0bdcfd4513b /train_ti.py | |
| parent | Fix | |
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py | 21
1 file changed, 7 insertions(+), 14 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 7f5fb49..45e730a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -484,19 +484,13 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--lora_r",
-        type=int,
-        default=8,
-        help="Lora rank, only used if use_lora is True"
-    )
-    parser.add_argument(
-        "--lora_alpha",
-        type=int,
-        default=32,
-        help="Lora alpha, only used if use_lora is True"
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
     )
     parser.add_argument(
-        "--lora_dropout",
+        "--emb_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
@@ -669,9 +663,8 @@ def main():
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path,
-        args.lora_r,
-        args.lora_alpha,
-        args.lora_dropout
+        args.emb_alpha,
+        args.emb_dropout
     )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
```
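Taken together, the commit replaces the LoRA hyperparameters (`--lora_r`, `--lora_alpha`, `--lora_dropout`) with embedding-level ones (`--emb_alpha`, `--emb_dropout`) and threads them through to `get_models()`. Below is a minimal, self-contained sketch of just the option surface this diff touches; the real `parse_args()` defines many more arguments, and the invocation values are hypothetical.

```python
import argparse

# Sketch of the two options as they exist after this commit; copied from
# the hunk in parse_args(). Everything else in the real parser is omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--emb_alpha",
    type=float,
    default=1.0,
    help="Embedding alpha"
)
parser.add_argument(
    "--emb_dropout",
    type=float,
    default=0,
    help="Embedding dropout probability.",
)

# Hypothetical invocation to show the parsed values.
args = parser.parse_args(["--emb_alpha", "0.5", "--emb_dropout", "0.1"])
print(args.emb_alpha, args.emb_dropout)  # -> 0.5 0.1

# In main(), these two values replace the three old LoRA arguments in the
# call to the repository's get_models() helper (call shape per the diff):
#   ... = get_models(
#       args.pretrained_model_name_or_path,
#       args.emb_alpha,
#       args.emb_dropout
#   )
```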
