diff options
author | Volpeon <git@volpeon.ink> | 2023-03-31 14:54:15 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-31 14:54:15 +0200 |
commit | 5acae38f9b995fbaeb42a1504cce88bd18154f12 (patch) | |
tree | 28abdb148fc133782fb5ee55b157cf1b12327c9d /train_lora.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.tar.gz textual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.tar.bz2 textual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.zip |
Fix
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 7b54ef8..d89b18d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -625,7 +625,6 @@ def main(): | |||
625 | dropout=args.tag_dropout, | 625 | dropout=args.tag_dropout, |
626 | shuffle=not args.no_tag_shuffle, | 626 | shuffle=not args.no_tag_shuffle, |
627 | template_key=args.train_data_template, | 627 | template_key=args.train_data_template, |
628 | placeholder_tokens=args.placeholder_tokens, | ||
629 | valid_set_size=args.valid_set_size, | 628 | valid_set_size=args.valid_set_size, |
630 | train_set_pad=args.train_set_pad, | 629 | train_set_pad=args.train_set_pad, |
631 | valid_set_pad=args.valid_set_pad, | 630 | valid_set_pad=args.valid_set_pad, |
@@ -636,9 +635,10 @@ def main(): | |||
636 | datamodule.setup() | 635 | datamodule.setup() |
637 | 636 | ||
638 | num_train_epochs = args.num_train_epochs | 637 | num_train_epochs = args.num_train_epochs |
638 | sample_frequency = args.sample_frequency | ||
639 | if num_train_epochs is None: | 639 | if num_train_epochs is None: |
640 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 640 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
641 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) | 641 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
642 | 642 | ||
643 | optimizer = create_optimizer( | 643 | optimizer = create_optimizer( |
644 | itertools.chain( | 644 | itertools.chain( |