| author | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 |
| commit | 86e908656bcd7585ec45cd930176800f759f146a (patch) | |
| tree | 1169e9b1728e4c6fc8b70e46a37080ae0794ada8 /train_ti.py | |
| parent | Experimental: TI via LoRA | |
Combined TI with embedding and LoRA
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 30 |
1 file changed, 5 insertions(+), 25 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 0ce0056..26ac384 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,26 +308,6 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -715,10 +696,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -780,7 +757,10 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.overlay.parameters(),
+        itertools.chain(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
```
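The substantive change is the optimizer setup: instead of training only the LoRA overlay, the trainer now feeds both the temporary token embedding (the textual-inversion vectors) and the overlay to a single optimizer via `itertools.chain`. Below is a minimal sketch of that pattern, using plain `torch.optim.AdamW` and hypothetical `nn.Embedding`/`nn.Linear` stand-ins for the repository's patched text encoder modules:

```python
import itertools

import torch
from torch import nn

# Hypothetical stand-ins for text_encoder.text_model.embeddings in the repo:
# `temp_token_embedding` holds the trainable textual-inversion vectors and
# `overlay` plays the role of the LoRA-style adapter over the embeddings.
temp_token_embedding = nn.Embedding(4, 768)   # 4 placeholder tokens
overlay = nn.Linear(768, 768, bias=False)     # adapter stand-in

# itertools.chain lazily concatenates both parameter iterators, so one
# optimizer updates the TI embeddings and the LoRA overlay in a single step.
optimizer = torch.optim.AdamW(
    itertools.chain(
        temp_token_embedding.parameters(),
        overlay.parameters(),
    ),
    lr=1e-4,
)

# A dummy training step updates both parameter groups together:
loss = overlay(temp_token_embedding(torch.tensor([0]))).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

In the actual script, `create_optimizer` is the repository's own factory (defaulting to the `dadan` optimizer per `--optimizer`), so `AdamW` above is only to keep the sketch self-contained. The commit also drops the now-unused EMA flags (`--use_ema`, `--ema_inv_gamma`, `--ema_power`, `--ema_max_decay`) and their pass-through to the trainer.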
