author    Volpeon <git@volpeon.ink>  2023-04-01 22:13:55 +0200
committer Volpeon <git@volpeon.ink>  2023-04-01 22:13:55 +0200
commit    208e48134e324e934ad964bdc61880cc923f4c0d
tree      c215f6c201c04b0b2d18ba0df230fb4c5e622985
parent    Fix

    Revert
Diffstat (limited to 'train_ti.py'):

 train_ti.py | 52 ++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -308,6 +307,26 @@ def parse_args():
308 help="Minimum learning rate in the lr scheduler." 307 help="Minimum learning rate in the lr scheduler."
309 ) 308 )
310 parser.add_argument( 309 parser.add_argument(
310 "--use_ema",
311 action="store_true",
312 help="Whether to use EMA model."
313 )
314 parser.add_argument(
315 "--ema_inv_gamma",
316 type=float,
317 default=1.0
318 )
319 parser.add_argument(
320 "--ema_power",
321 type=float,
322 default=4/5
323 )
324 parser.add_argument(
325 "--ema_max_decay",
326 type=float,
327 default=0.9999
328 )
329 parser.add_argument(
311 "--optimizer", 330 "--optimizer",
312 type=str, 331 type=str,
313 default="dadan", 332 default="dadan",
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -432,6 +451,23 @@ def parse_args():
432 help="The weight of prior preservation loss." 451 help="The weight of prior preservation loss."
433 ) 452 )
434 parser.add_argument( 453 parser.add_argument(
454 "--use_emb_decay",
455 action="store_true",
456 help="Whether to use embedding decay."
457 )
458 parser.add_argument(
459 "--emb_decay_target",
460 default=0.4,
461 type=float,
462 help="Embedding decay target."
463 )
464 parser.add_argument(
465 "--emb_decay",
466 default=1e2,
467 type=float,
468 help="Embedding decay factor."
469 )
470 parser.add_argument(
435 "--noise_timesteps", 471 "--noise_timesteps",
436 type=int, 472 type=int,
437 default=1000, 473 default=1000,
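
--use_emb_decay and its two companions suggest a post-step regularizer that pulls each learned embedding's norm toward emb_decay_target. The strategy itself is not shown in this diff; one plausible reading of the knobs, with the function name and exact update being ours rather than the repo's:

    import torch

    @torch.no_grad()
    def apply_emb_decay(weight: torch.Tensor, lr: float,
                        emb_decay: float = 1e2,
                        emb_decay_target: float = 0.4) -> None:
        # Hypothetical sketch: after each optimizer step, nudge every
        # embedding vector's L2 norm toward the target; emb_decay scales
        # how strongly the norm is corrected.
        norm = weight.norm(dim=-1, keepdim=True)
        weight.add_(weight / norm.clamp_min(1e-8) * (emb_decay_target - norm),
                    alpha=lr * emb_decay)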
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
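
This hunk only threads the new arguments through to the trainer. As a rough illustration of where they end up, an EMA copy of the embedding would typically be refreshed after each optimizer step along these lines (all names are stand-ins, reusing the hypothetical ema_decay_at sketched above):

    import torch

    @torch.no_grad()
    def ema_update(ema_weight: torch.Tensor, weight: torch.Tensor,
                   decay: float) -> None:
        # Standard exponential moving average: ema <- decay * ema + (1 - decay) * w.
        ema_weight.mul_(decay).add_(weight, alpha=1.0 - decay)

    # Inside the training loop (sketch):
    #   decay = ema_decay_at(global_step, args.ema_inv_gamma,
    #                        args.ema_power, args.ema_max_decay)
    #   ema_update(ema_embedding.weight, embedding.weight, decay)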
@@ -757,10 +800,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )

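
This reverted optimizer setup is also why the itertools import disappears in the first hunk: with the overlay parameters dropped, only a single parameter iterator remains and itertools.chain has nothing left to merge. Assuming create_optimizer is, as the call shape suggests, a factory taking (params, lr=...), the post-revert call reduces to:

    # Only the temporary token embedding is optimized after the revert.
    optimizer = create_optimizer(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=args.learning_rate,
    )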