diff options
author | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
commit | 16b92605a59d59c65789c89b54bb97da51908056 (patch) | |
tree | b0cbf8677897c3f44c736b710fd034eb2c5de6a0 /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.gz textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.bz2 textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.zip |
Embedding normalization: Ignore tensors with grad = 0
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 21 |
1 file changed, 10 insertions, 11 deletions
diff --git a/train_ti.py b/train_ti.py index 12e3644..6dc07dd 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -86,7 +86,7 @@ def parse_args(): | |||
86 | help="Number of vectors per embedding." | 86 | help="Number of vectors per embedding." |
87 | ) | 87 | ) |
88 | parser.add_argument( | 88 | parser.add_argument( |
89 | "--simultaneous", | 89 | "--sequential", |
90 | action="store_true", | 90 | action="store_true", |
91 | ) | 91 | ) |
92 | parser.add_argument( | 92 | parser.add_argument( |
@@ -293,7 +293,7 @@ def parse_args(): | |||
293 | "--optimizer", | 293 | "--optimizer", |
294 | type=str, | 294 | type=str, |
295 | default="adam", | 295 | default="adam", |
296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 296 | help='Optimizer to use ["adam", "adam8bit"]' |
297 | ) | 297 | ) |
298 | parser.add_argument( | 298 | parser.add_argument( |
299 | "--adam_beta1", | 299 | "--adam_beta1", |
@@ -343,6 +343,11 @@ def parse_args(): | |||
343 | help="How often to save a checkpoint and sample image (in epochs)", | 343 | help="How often to save a checkpoint and sample image (in epochs)", |
344 | ) | 344 | ) |
345 | parser.add_argument( | 345 | parser.add_argument( |
346 | "--no_milestone_checkpoints", | ||
347 | action='store_true', | ||
348 | help="If checkpoints are saved on maximum accuracy", | ||
349 | ) | ||
350 | parser.add_argument( | ||
346 | "--sample_frequency", | 351 | "--sample_frequency", |
347 | type=int, | 352 | type=int, |
348 | default=1, | 353 | default=1, |
@@ -480,7 +485,7 @@ def parse_args(): | |||
480 | if len(args.placeholder_tokens) != len(args.num_vectors): | 485 | if len(args.placeholder_tokens) != len(args.num_vectors): |
481 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 486 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
482 | 487 | ||
483 | if not args.simultaneous: | 488 | if args.sequential: |
484 | if isinstance(args.train_data_template, str): | 489 | if isinstance(args.train_data_template, str): |
485 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 490 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) |
486 | 491 | ||
@@ -586,13 +591,6 @@ def main(): | |||
586 | eps=args.adam_epsilon, | 591 | eps=args.adam_epsilon, |
587 | amsgrad=args.adam_amsgrad, | 592 | amsgrad=args.adam_amsgrad, |
588 | ) | 593 | ) |
589 | elif args.optimizer == 'lion': | ||
590 | try: | ||
591 | from lion_pytorch import Lion | ||
592 | except ImportError: | ||
593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
594 | |||
595 | create_optimizer = partial(Lion, use_triton=True) | ||
596 | else: | 594 | else: |
597 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 595 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
598 | 596 | ||
@@ -615,6 +613,7 @@ def main(): | |||
615 | num_train_epochs=args.num_train_epochs, | 613 | num_train_epochs=args.num_train_epochs, |
616 | sample_frequency=args.sample_frequency, | 614 | sample_frequency=args.sample_frequency, |
617 | checkpoint_frequency=args.checkpoint_frequency, | 615 | checkpoint_frequency=args.checkpoint_frequency, |
616 | milestone_checkpoints=not args.no_milestone_checkpoints, | ||
618 | global_step_offset=global_step_offset, | 617 | global_step_offset=global_step_offset, |
619 | # -- | 618 | # -- |
620 | tokenizer=tokenizer, | 619 | tokenizer=tokenizer, |
@@ -715,7 +714,7 @@ def main(): | |||
715 | 714 | ||
716 | plot_metrics(metrics, metrics_output_file) | 715 | plot_metrics(metrics, metrics_output_file) |
717 | 716 | ||
718 | if args.simultaneous: | 717 | if not args.sequential: |
719 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 718 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
720 | else: | 719 | else: |
721 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 720 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |