summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-06 16:25:36 +0200
committerVolpeon <git@volpeon.ink>2023-05-06 16:25:36 +0200
commit7b04d813739c0b5595295dffdc86cc41108db2d3 (patch)
tree8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz
textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2
textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index ae73639..26f7941 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,6 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
23from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
24from training.strategy.ti import textual_inversion_strategy 24from training.strategy.ti import textual_inversion_strategy
25from training.optimization import get_scheduler 25from training.optimization import get_scheduler
26from training.sampler import create_named_schedule_sampler
26from training.util import AverageMeter, save_args 27from training.util import AverageMeter, save_args
27 28
28logger = get_logger(__name__) 29logger = get_logger(__name__)
@@ -359,6 +360,19 @@ def parse_args():
359 default=0.9999 360 default=0.9999
360 ) 361 )
361 parser.add_argument( 362 parser.add_argument(
363 "--min_snr_gamma",
364 type=int,
365 default=5,
366 help="MinSNR gamma."
367 )
368 parser.add_argument(
369 "--schedule_sampler",
370 type=str,
371 default="uniform",
372 choices=["uniform", "loss-second-moment"],
373 help="Noise schedule sampler."
374 )
375 parser.add_argument(
362 "--optimizer", 376 "--optimizer",
363 type=str, 377 type=str,
364 default="adan", 378 default="adan",
@@ -682,6 +696,7 @@ def main():
682 args.emb_alpha, 696 args.emb_alpha,
683 args.emb_dropout 697 args.emb_dropout
684 ) 698 )
699 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
685 700
686 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 701 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
687 tokenizer.set_dropout(args.vector_dropout) 702 tokenizer.set_dropout(args.vector_dropout)
@@ -837,6 +852,8 @@ def main():
837 tokenizer=tokenizer, 852 tokenizer=tokenizer,
838 vae=vae, 853 vae=vae,
839 noise_scheduler=noise_scheduler, 854 noise_scheduler=noise_scheduler,
855 schedule_sampler=schedule_sampler,
856 min_snr_gamma=args.min_snr_gamma,
840 dtype=weight_dtype, 857 dtype=weight_dtype,
841 seed=args.seed, 858 seed=args.seed,
842 compile_unet=args.compile_unet, 859 compile_unet=args.compile_unet,