| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
| commit | 21d70916f66e74a87c631a06b70774954b085b48 (patch) | |
| tree | d1b443b9270f45ae6936f3acb565f767c7c65b1f /train_lora.py | |
| parent | Run PTI only if placeholder tokens arg isn't empty (diff) | |
Fix
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 13 |

1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index daf1f6c..476efcf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -548,15 +548,18 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if args.initializer_tokens is None:
+        args.initializer_tokens = []
+
+    if args.placeholder_tokens is None:
+        args.placeholder_tokens = []
+
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
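The first hunk makes both token arguments optional: `None` now normalizes to an empty list before the `isinstance` checks run, and the hard `--initializer_tokens` requirement is dropped, so the PTI phase can be skipped entirely when no tokens are given (the behavior introduced by the parent commit, "Run PTI only if placeholder tokens arg isn't empty"). A minimal standalone sketch of the resulting normalization, lifted out of argparse; the function name and the bare example values are illustrative only:

```python
# Sketch of the normalization parse_args() performs after this commit.
def normalize_token_args(placeholder_tokens=None, initializer_tokens=None):
    if initializer_tokens is None:
        initializer_tokens = []
    if placeholder_tokens is None:
        placeholder_tokens = []
    if isinstance(placeholder_tokens, str):
        placeholder_tokens = [placeholder_tokens]
    if isinstance(initializer_tokens, str):
        # A single initializer token is repeated once per placeholder token.
        initializer_tokens = [initializer_tokens] * len(placeholder_tokens)
    if len(placeholder_tokens) == 0:
        # Auto-generate placeholder names when only initializers were given.
        placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]
    return placeholder_tokens, initializer_tokens

print(normalize_token_args())                         # ([], []) -> PTI is skipped
print(normalize_token_args("<cat>", "cat"))           # (['<cat>'], ['cat'])
print(normalize_token_args(["<a>", "<b>"], "thing"))  # (['<a>', '<b>'], ['thing', 'thing'])
```

Before this commit, omitting both arguments raised `ValueError("You must specify --initializer_tokens")` before the skip-PTI check could ever run. The remaining two hunks of the same diff adjust the PTI sampling schedule: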
@@ -884,7 +887,7 @@ def main():
     num_pti_epochs = math.ceil(
         args.num_pti_steps / len(pti_datamodule.train_dataset)
     ) * args.pti_gradient_accumulation_steps
-    pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+    pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
     pti_optimizer = create_optimizer(
         [
@@ -915,7 +918,7 @@ def main():
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
+        sample_frequency=math.inf,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         use_emb_decay=args.use_emb_decay,
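The second hunk rescales `pti_sample_frequency` against `args.num_pti_steps` rather than `args.num_train_steps`; dividing by the full training step count made the ratio far too small for the short PTI phase, which appears to be the bug the commit message refers to. The third hunk then passes `math.inf` as the trainer's sample frequency, apparently to suppress intermediate sampling during PTI altogether. A quick numeric sketch of the rescaling, using made-up values for arguments that are defined elsewhere in train_lora.py:

```python
import math

# All values below are hypothetical stand-ins for argparse values and the
# dataset length computed earlier in main().
num_pti_steps = 500         # args.num_pti_steps
num_train_steps = 5000      # args.num_train_steps
dataset_len = 100           # len(pti_datamodule.train_dataset)
grad_accum = 1              # args.pti_gradient_accumulation_steps
pti_sample_frequency = 200  # pre-scaling value set earlier in main()

num_pti_epochs = math.ceil(num_pti_steps / dataset_len) * grad_accum  # 5

# Old: scaled by the total training steps, the ratio collapses toward
# ceil(small fraction) == 1 no matter how long the PTI phase is.
old = math.ceil(num_pti_epochs * (pti_sample_frequency / num_train_steps))  # ceil(0.2) == 1
# New: scaled by the PTI step count, the frequency stays proportional to
# the PTI schedule itself.
new = math.ceil(num_pti_epochs * (pti_sample_frequency / num_pti_steps))    # ceil(2.0) == 2
```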
