author    Volpeon <git@volpeon.ink>  2023-04-04 07:30:43 +0200
committer Volpeon <git@volpeon.ink>  2023-04-04 07:30:43 +0200
commit    30b557c8e1f03b4748ac3efca599ff51d66561cb (patch)
tree      59aaacde83a7a44dc267c64455f6dc2cfb90c01f /train_ti.py
parent    Improved sparse embeddings (diff)
TI: Bring back old embedding decay
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  24 +++++++++++++++++++-----
1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index a9a2333..4366c9e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,10 +451,21 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_alpha",
-        default=1.0,
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
         type=float,
-        help="Embedding alpha."
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -600,7 +611,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_alpha)
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -744,6 +755,9 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
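
The trainer code that consumes the new flags is outside this diff. For orientation, here is a minimal sketch of what the restored decay step could look like, assuming it shrinks each trained embedding vector's L2 norm toward --emb_decay_target at a rate set by --emb_decay and the current learning rate; the function name apply_emb_decay and the wiring are hypothetical, not taken from the repository:

import torch

@torch.no_grad()
def apply_emb_decay(
    weight: torch.Tensor,            # full token-embedding matrix, shape (vocab, dim)
    ids: torch.Tensor,               # indices of the trained placeholder tokens
    lr: float,                       # current learning rate
    emb_decay: float = 1e2,          # --emb_decay
    emb_decay_target: float = 0.4,   # --emb_decay_target
) -> None:
    """Shrink each trained embedding's L2 norm toward emb_decay_target."""
    w = weight[ids]                                  # rows being trained
    norm = w.norm(dim=-1, keepdim=True)
    # Decay only vectors whose norm exceeds the target; scaling by lr keeps
    # the pull in step with the optimizer schedule, and clamping to [0, 1)
    # prevents a very large norm from flipping the sign of the update.
    scale = (lr * emb_decay * (norm - emb_decay_target)).clamp(0.0, 0.999)
    weight[ids] = w * (1.0 - scale)

A caller would run this once per step when args.use_emb_decay is set, e.g. after optimizer.step(), passing the text encoder's input-embedding weights and the placeholder token ids. With the defaults above (emb_decay=1e+2, emb_decay_target=0.4) and a typical learning rate of 1e-3, a vector of norm 1.4 shrinks by about 10% per step, and the pull vanishes as the norm approaches the target.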