| author | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
| commit | 208e48134e324e934ad964bdc61880cc923f4c0d (patch) | |
| tree | c215f6c201c04b0b2d18ba0df230fb4c5e622985 /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.gz textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.bz2 textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.zip | |
Revert
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 52 |
1 file changed, 46 insertions(+), 6 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -308,6 +307,26 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
```diff
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -432,6 +451,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
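```

`--use_emb_decay` and its two companions regularize the learned token embeddings: `--emb_decay_target` reads as a target L2 norm and `--emb_decay` as the strength of the pull toward it. A hypothetical sketch of such a hook (the function name and update rule are assumptions, not this repo's verified implementation):

```python
import torch


@torch.no_grad()
def apply_emb_decay(weight: torch.Tensor, lr: float,
                    target: float = 0.4, decay: float = 1e2) -> None:
    """Shrink embedding rows whose L2 norm exceeds `target`.

    Hypothetical reading of --emb_decay_target / --emb_decay; the
    trainer's actual update may differ.
    """
    norm = weight.norm(dim=-1, keepdim=True)
    # Overshoot is zero for rows already at or below the target norm.
    overshoot = (norm - target).clamp_min(0.0)
    scale = 1.0 - lr * decay * overshoot / norm.clamp_min(1e-12)
    weight.mul_(scale.clamp_min(0.0))
```

Under this reading, with the defaults and a learning rate of 1e-4, a row at norm 0.5 is scaled by 1 - 1e-2 * 0.1 / 0.5 = 0.998 per step, so oversized embeddings shrink gradually rather than being clipped.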
```diff
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -757,10 +800,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )
 
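```

The final hunk is why the `import itertools` line could be dropped: with the `overlay` parameters gone, only a single `.parameters()` generator remains, and a lone generator needs no chaining. A self-contained illustration (using `torch.optim.AdamW` as a stand-in for the repo's `create_optimizer`):

```python
import itertools

import torch

emb_main = torch.nn.Embedding(4, 8)
emb_overlay = torch.nn.Embedding(4, 8)

# Two .parameters() generators must be merged into one iterable:
opt = torch.optim.AdamW(
    itertools.chain(emb_main.parameters(), emb_overlay.parameters()),
    lr=1e-4,
)

# A single generator is already an iterable, so once the overlay
# parameters are gone, chain() and its import become dead weight.
opt = torch.optim.AdamW(emb_main.parameters(), lr=1e-4)
```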
