diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
| commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 (patch) | |
| tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2 textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 17 |
1 file changed, 17 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index ae73639..26f7941 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -23,6 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
| 23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
| 24 | from training.strategy.ti import textual_inversion_strategy | 24 | from training.strategy.ti import textual_inversion_strategy |
| 25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
| 26 | from training.sampler import create_named_schedule_sampler | ||
| 26 | from training.util import AverageMeter, save_args | 27 | from training.util import AverageMeter, save_args |
| 27 | 28 | ||
| 28 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
| @@ -359,6 +360,19 @@ def parse_args(): | |||
| 359 | default=0.9999 | 360 | default=0.9999 |
| 360 | ) | 361 | ) |
| 361 | parser.add_argument( | 362 | parser.add_argument( |
| 363 | "--min_snr_gamma", | ||
| 364 | type=int, | ||
| 365 | default=5, | ||
| 366 | help="MinSNR gamma." | ||
| 367 | ) | ||
| 368 | parser.add_argument( | ||
| 369 | "--schedule_sampler", | ||
| 370 | type=str, | ||
| 371 | default="uniform", | ||
| 372 | choices=["uniform", "loss-second-moment"], | ||
| 373 | help="Noise schedule sampler." | ||
| 374 | ) | ||
| 375 | parser.add_argument( | ||
| 362 | "--optimizer", | 376 | "--optimizer", |
| 363 | type=str, | 377 | type=str, |
| 364 | default="adan", | 378 | default="adan", |
| @@ -682,6 +696,7 @@ def main(): | |||
| 682 | args.emb_alpha, | 696 | args.emb_alpha, |
| 683 | args.emb_dropout | 697 | args.emb_dropout |
| 684 | ) | 698 | ) |
| 699 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | ||
| 685 | 700 | ||
| 686 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 701 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 687 | tokenizer.set_dropout(args.vector_dropout) | 702 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -837,6 +852,8 @@ def main(): | |||
| 837 | tokenizer=tokenizer, | 852 | tokenizer=tokenizer, |
| 838 | vae=vae, | 853 | vae=vae, |
| 839 | noise_scheduler=noise_scheduler, | 854 | noise_scheduler=noise_scheduler, |
| 855 | schedule_sampler=schedule_sampler, | ||
| 856 | min_snr_gamma=args.min_snr_gamma, | ||
| 840 | dtype=weight_dtype, | 857 | dtype=weight_dtype, |
| 841 | seed=args.seed, | 858 | seed=args.seed, |
| 842 | compile_unet=args.compile_unet, | 859 | compile_unet=args.compile_unet, |
