From 7b04d813739c0b5595295dffdc86cc41108db2d3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 6 May 2023 16:25:36 +0200 Subject: Update --- train_ti.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index ae73639..26f7941 100644 --- a/train_ti.py +++ b/train_ti.py @@ -23,6 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter from training.functional import train, add_placeholder_tokens, get_models from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler +from training.sampler import create_named_schedule_sampler from training.util import AverageMeter, save_args logger = get_logger(__name__) @@ -358,6 +359,19 @@ def parse_args(): type=float, default=0.9999 ) + parser.add_argument( + "--min_snr_gamma", + type=int, + default=5, + help="MinSNR gamma." + ) + parser.add_argument( + "--schedule_sampler", + type=str, + default="uniform", + choices=["uniform", "loss-second-moment"], + help="Noise schedule sampler." + ) parser.add_argument( "--optimizer", type=str, @@ -682,6 +696,7 @@ def main(): args.emb_alpha, args.emb_dropout ) + schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -837,6 +852,8 @@ def main(): tokenizer=tokenizer, vae=vae, noise_scheduler=noise_scheduler, + schedule_sampler=schedule_sampler, + min_snr_gamma=args.min_snr_gamma, dtype=weight_dtype, seed=args.seed, compile_unet=args.compile_unet, -- cgit v1.2.3-54-g00ecf