diff options
author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 (patch) | |
tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2 textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index ae73639..26f7941 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -23,6 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
24 | from training.strategy.ti import textual_inversion_strategy | 24 | from training.strategy.ti import textual_inversion_strategy |
25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
26 | from training.sampler import create_named_schedule_sampler | ||
26 | from training.util import AverageMeter, save_args | 27 | from training.util import AverageMeter, save_args |
27 | 28 | ||
28 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
@@ -359,6 +360,19 @@ def parse_args(): | |||
359 | default=0.9999 | 360 | default=0.9999 |
360 | ) | 361 | ) |
361 | parser.add_argument( | 362 | parser.add_argument( |
363 | "--min_snr_gamma", | ||
364 | type=int, | ||
365 | default=5, | ||
366 | help="MinSNR gamma." | ||
367 | ) | ||
368 | parser.add_argument( | ||
369 | "--schedule_sampler", | ||
370 | type=str, | ||
371 | default="uniform", | ||
372 | choices=["uniform", "loss-second-moment"], | ||
373 | help="Noise schedule sampler." | ||
374 | ) | ||
375 | parser.add_argument( | ||
362 | "--optimizer", | 376 | "--optimizer", |
363 | type=str, | 377 | type=str, |
364 | default="adan", | 378 | default="adan", |
@@ -682,6 +696,7 @@ def main(): | |||
682 | args.emb_alpha, | 696 | args.emb_alpha, |
683 | args.emb_dropout | 697 | args.emb_dropout |
684 | ) | 698 | ) |
699 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | ||
685 | 700 | ||
686 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 701 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
687 | tokenizer.set_dropout(args.vector_dropout) | 702 | tokenizer.set_dropout(args.vector_dropout) |
@@ -837,6 +852,8 @@ def main(): | |||
837 | tokenizer=tokenizer, | 852 | tokenizer=tokenizer, |
838 | vae=vae, | 853 | vae=vae, |
839 | noise_scheduler=noise_scheduler, | 854 | noise_scheduler=noise_scheduler, |
855 | schedule_sampler=schedule_sampler, | ||
856 | min_snr_gamma=args.min_snr_gamma, | ||
840 | dtype=weight_dtype, | 857 | dtype=weight_dtype, |
841 | seed=args.seed, | 858 | seed=args.seed, |
842 | compile_unet=args.compile_unet, | 859 | compile_unet=args.compile_unet, |