path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-04-01 16:30:36 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-01 16:30:36 +0200
commit     c96073646bbb638d7d78fdd7d9fdeed08d1454b5 (patch)
tree       3e0846964fa127844d652e2dee081cd67e672e6a  /train_ti.py
parent     Update (diff)
download   textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.gz
           textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.tar.bz2
           textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5.zip
Experimental: TI via LoRA
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  24
1 file changed, 2 insertions(+), 22 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 5482326..0ce0056 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,23 +451,6 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
-        type=float,
-        help="Embedding decay factor."
-    )
-    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -732,9 +715,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
@@ -800,7 +780,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.overlay.parameters(),
         lr=args.learning_rate,
     )
 
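Note on the change above: the commit message describes this as experimental textual inversion (TI) trained via LoRA, the embedding-decay flags are dropped, and the optimizer target switches from the old temp_token_embedding to text_encoder.text_model.embeddings.overlay.parameters(). The overlay module itself is not part of this diff, so the snippet below is only a minimal sketch of what a LoRA-style overlay on a frozen token embedding could look like; the class name EmbeddingOverlay, the rank/alpha hyperparameters, and the AdamW values are hypothetical and not taken from the repository.

import torch
import torch.nn as nn


class EmbeddingOverlay(nn.Module):
    """Hypothetical LoRA-style low-rank delta over a frozen token embedding."""

    def __init__(self, base: nn.Embedding, rank: int = 8, alpha: float = 8.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # base embedding stays frozen
        # LoRA factors: delta = (A @ B) * (alpha / rank)
        self.lora_a = nn.Parameter(torch.randn(base.num_embeddings, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, base.embedding_dim))
        self.scale = alpha / rank

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Full (vocab, dim) low-rank update, then gather the rows for the given ids
        delta = (self.lora_a @ self.lora_b) * self.scale
        return self.base(input_ids) + delta[input_ids]


# Usage sketch: only the low-rank factors are handed to the optimizer, mirroring
# create_optimizer(text_encoder.text_model.embeddings.overlay.parameters(), ...) above.
# Vocabulary size, rank, learning rate, and weight decay here are placeholders.
base_embedding = nn.Embedding(49408, 768)
overlay = EmbeddingOverlay(base_embedding, rank=8)
trainable = [p for p in overlay.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=1e-2)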