From e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Apr 2023 18:52:30 +0200 Subject: TI: Delta learning --- train_ti.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 8dde1ba..0ad7574 100644 --- a/train_ti.py +++ b/train_ti.py @@ -353,7 +353,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=1e-2, help="Weight decay to use." ) parser.add_argument( @@ -451,21 +451,10 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--use_emb_decay", - action="store_true", - help="Whether to use embedding decay." - ) - parser.add_argument( - "--emb_decay_target", - default=0.4, - type=float, - help="Embedding decay target." - ) - parser.add_argument( - "--emb_decay", - default=1e2, + "--emb_alpha", + default=1.0, type=float, - help="Embedding decay factor." + help="Embedding alpha." ) parser.add_argument( "--noise_timesteps", @@ -567,16 +556,16 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer in ('adam', 'adam8bit'): - args.adam_beta1 = 0.9 - elif args.optimizer == 'lion': + if args.optimizer == 'lion': args.adam_beta1 = 0.95 + else: + args.adam_beta1 = 0.9 if args.adam_beta2 is None: - if args.optimizer in ('adam', 'adam8bit'): - args.adam_beta2 = 0.999 - elif args.optimizer == 'lion': + if args.optimizer == 'lion': args.adam_beta2 = 0.98 + else: + args.adam_beta2 = 0.999 return args @@ -611,7 +600,7 @@ def main(): save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path) + args.pretrained_model_name_or_path, args.emb_alpha) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -755,10 +744,6 @@ def main(): tokenizer=tokenizer, sample_scheduler=sample_scheduler, checkpoint_output_dir=checkpoint_output_dir, - gradient_checkpointing=args.gradient_checkpointing, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay=args.emb_decay, use_ema=args.use_ema, ema_inv_gamma=args.ema_inv_gamma, ema_power=args.ema_power, -- cgit v1.2.3-54-g00ecf