From 16b92605a59d59c65789c89b54bb97da51908056 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Feb 2023 09:09:50 +0100
Subject: Embedding normalization: Ignore tensors with grad = 0

---
 train_ti.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
-        "--simultaneous",
+        "--sequential",
         action="store_true",
     )
     parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -342,6 +342,11 @@ def parse_args():
         default=5,
         help="How often to save a checkpoint and sample image (in epochs)",
     )
+    parser.add_argument(
+        "--no_milestone_checkpoints",
+        action='store_true',
+        help="If checkpoints are saved on maximum accuracy",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
--
cgit v1.2.3-54-g00ecf
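
Note: this view is limited to train_ti.py, so the embedding-normalization change named in the subject line is not visible in the diff above. Below is a minimal sketch of what "ignore tensors with grad = 0" can mean when capping the norm of learned embedding vectors; it is an illustration, not code from this commit, and all names (normalize_embeddings, max_norm) are assumptions.

import torch

@torch.no_grad()
def normalize_embeddings(embeddings: torch.nn.Parameter, max_norm: float = 1.0):
    # Hypothetical sketch: skip embedding vectors whose gradient is exactly
    # zero, i.e. rows that were not updated by the last optimizer step.
    if embeddings.grad is None:
        return
    updated = embeddings.grad.abs().sum(dim=-1) != 0
    if not updated.any():
        return
    norms = embeddings.data[updated].norm(dim=-1, keepdim=True)
    # Scale overlong rows back down to max_norm; leave the rest untouched.
    scale = (max_norm / norms.clamp(min=1e-12)).clamp(max=1.0)
    embeddings.data[updated] *= scale

Called once per training step after optimizer.step(), a routine like this leaves untrained rows alone instead of repeatedly renormalizing embeddings that did not change.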