author     Volpeon <git@volpeon.ink>  2023-04-07 14:14:00 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 14:14:00 +0200
commit     21d70916f66e74a87c631a06b70774954b085b48 (patch)
tree       d1b443b9270f45ae6936f3acb565f767c7c65b1f /train_lora.py
parent     Run PTI only if placeholder tokens arg isn't empty (diff)
Fix
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  13
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index daf1f6c..476efcf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -548,15 +548,18 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if args.initializer_tokens is None:
+        args.initializer_tokens = []
+
+    if args.placeholder_tokens is None:
+        args.placeholder_tokens = []
+
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
@@ -884,7 +887,7 @@ def main():
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
         ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
         pti_optimizer = create_optimizer(
             [
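This hunk fixes a scaling bug: the PTI sampling interval was scaled against args.num_train_steps (the length of the later LoRA phase) rather than args.num_pti_steps, compressing the frequency far below the intended value. A worked example with hypothetical numbers:

    import math

    len_train_dataset = 100                 # items per epoch
    pti_gradient_accumulation_steps = 1
    num_pti_steps = 1_000                   # steps in the PTI phase
    num_train_steps = 10_000                # steps in the later LoRA phase
    sample_frequency = 200                  # desired sampling interval, in steps

    num_pti_epochs = math.ceil(num_pti_steps / len_train_dataset) * pti_gradient_accumulation_steps

    # Old (buggy): scaled against the LoRA phase length -> samples every epoch.
    assert math.ceil(num_pti_epochs * (sample_frequency / num_train_steps)) == 1
    # New: scaled against the PTI phase length -> every 2 epochs = every 200 steps.
    assert math.ceil(num_pti_epochs * (sample_frequency / num_pti_steps)) == 2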
@@ -915,7 +918,7 @@ def main():
             # --
             sample_output_dir=pti_sample_output_dir,
             checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
+            sample_frequency=math.inf,
             placeholder_tokens=args.placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
             use_emb_decay=args.use_emb_decay,
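This hunk passes math.inf as the PTI trainer's sample_frequency instead of the value computed above, disabling periodic sampling during the PTI phase entirely. A sketch of how an infinite frequency neutralizes a periodic trigger (an assumption about the check; the actual loop in this repo may differ):

    import math

    def should_sample(epoch: int, sample_frequency: float) -> bool:
        if math.isinf(sample_frequency):
            return False                      # math.inf => never sample
        return epoch % sample_frequency == 0  # e.g. every N epochs

    assert not should_sample(4, math.inf)
    assert should_sample(4, 2)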