path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-04-03 18:52:30 +0200
committer Volpeon <git@volpeon.ink>  2023-04-03 18:52:30 +0200
commit    e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 (patch)
tree      87fbb9d92233aa1bb7342e31aec64d6d375f41e1 /train_ti.py
parent    TI: No tag dropout by default (diff)
download  textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.gz
          textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.bz2
          textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.zip

TI: Delta learning

Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 37
1 file changed, 11 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8dde1ba..0ad7574 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
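The new default of 1e-2 matches the weight_decay default of torch.optim.AdamW, so the learned embeddings are now mildly regularized unless --adam_weight_decay is overridden on the command line. A minimal sketch of how the value reaches the optimizer (the parameter group below is illustrative, not the actual one built in train_ti.py):

import torch

# Illustrative parameter group; in train_ti.py the trainable text
# embeddings end up here. weight_decay now defaults to 1e-2 instead of 0.
params = [torch.nn.Parameter(torch.zeros(2, 768))]
optimizer = torch.optim.AdamW(
    params,
    lr=1e-4,             # illustrative learning rate
    betas=(0.9, 0.999),  # see the beta1/beta2 defaults further down
    weight_decay=1e-2,   # value of the new --adam_weight_decay default
)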
@@ -451,21 +451,10 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
+        "--emb_alpha",
+        default=1.0,
         type=float,
-        help="Embedding decay factor."
+        help="Embedding alpha."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -567,16 +556,16 @@ def parse_args():
         raise ValueError("You must specify --output_dir")

     if args.adam_beta1 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta1 = 0.9
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta1 = 0.95
+        else:
+            args.adam_beta1 = 0.9

     if args.adam_beta2 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta2 = 0.999
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta2 = 0.98
+        else:
+            args.adam_beta2 = 0.999

     return args

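After the restructuring, only the Lion optimizer keeps its (0.95, 0.98) betas; any other --optimizer value (adam, adam8bit, or future choices) now falls back to the standard Adam defaults instead of relying on an explicit whitelist. Equivalent standalone logic, for illustration only:

def default_betas(optimizer: str) -> tuple[float, float]:
    # Lion prefers a higher beta1 and a lower beta2; everything else
    # gets the classic Adam defaults.
    if optimizer == 'lion':
        return 0.95, 0.98
    return 0.9, 0.999

assert default_betas('lion') == (0.95, 0.98)
assert default_betas('adam8bit') == (0.9, 0.999)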
@@ -611,7 +600,7 @@ def main():
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_alpha)

     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
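get_models() now also receives args.emb_alpha, presumably so the scale can be wired into the patched text-encoder embeddings at load time. The function body is not part of this diff; the signature below is a hedged sketch inferred only from the call site above:

# Hypothetical signature, inferred from the updated call site; the actual
# implementation lives outside this diff.
def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
    ...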
@@ -755,10 +744,6 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,