From 7b04d813739c0b5595295dffdc86cc41108db2d3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 6 May 2023 16:25:36 +0200
Subject: Update

---
 train_lora.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/train_lora.py b/train_lora.py
index cc7c1ec..70fbae4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -27,6 +27,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
+from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
@@ -409,6 +410,19 @@ def parse_args():
         default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
+    parser.add_argument(
+        "--min_snr_gamma",
+        type=int,
+        default=5,
+        help="MinSNR gamma."
+    )
+    parser.add_argument(
+        "--schedule_sampler",
+        type=str,
+        default="uniform",
+        choices=["uniform", "loss-second-moment"],
+        help="Noise schedule sampler."
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -708,6 +722,7 @@ def main():
         args.emb_alpha,
         args.emb_dropout
     )
+    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -923,6 +938,8 @@ def main():
         tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
+        schedule_sampler=schedule_sampler,
+        min_snr_gamma=args.min_snr_gamma,
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
--
cgit v1.2.3-54-g00ecf
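
Editor's note: the new training/sampler.py module imported above is not part of this
patch. Its name and the sampler choices ("uniform", "loss-second-moment") match the
schedule samplers in OpenAI's improved-diffusion resample.py, so a plausible sketch of
what create_named_schedule_sampler provides follows. This is an assumption, not the
actual file; OpenAI's original takes a diffusion object, whereas this patch passes
num_train_timesteps, so the sketch is adapted accordingly and the distributed loss
synchronization from the original is omitted.

    # Hypothetical sketch of training/sampler.py, modeled on OpenAI's
    # improved-diffusion resample.py; not the actual module from this repo.
    import numpy as np
    import torch


    class ScheduleSampler:
        """Samples diffusion timesteps and importance weights for a batch."""

        def weights(self):
            raise NotImplementedError

        def sample(self, batch_size, device):
            w = self.weights()
            p = w / np.sum(w)
            indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
            indices = torch.from_numpy(indices_np).long().to(device)
            # Importance weights undo the sampling bias, keeping the
            # expected loss unbiased under non-uniform timestep sampling.
            weights_np = 1 / (len(p) * p[indices_np])
            weights = torch.from_numpy(weights_np).float().to(device)
            return indices, weights


    class UniformSampler(ScheduleSampler):
        def __init__(self, num_timesteps):
            self._weights = np.ones([num_timesteps])

        def weights(self):
            return self._weights


    class LossSecondMomentResampler(ScheduleSampler):
        """Oversamples timesteps whose recent losses have a high second moment."""

        def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001):
            self.num_timesteps = num_timesteps
            self.history_per_term = history_per_term
            self.uniform_prob = uniform_prob
            self._loss_history = np.zeros([num_timesteps, history_per_term])
            self._loss_counts = np.zeros([num_timesteps], dtype=np.int64)

        def weights(self):
            if not (self._loss_counts == self.history_per_term).all():
                # Sample uniformly until every timestep has a full loss history.
                return np.ones([self.num_timesteps])
            w = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
            w /= w.sum()
            w *= 1 - self.uniform_prob
            w += self.uniform_prob / len(w)
            return w

        def update_with_losses(self, ts, losses):
            # Called by the training loop with detached per-sample losses.
            for t, loss in zip(ts.tolist(), losses.tolist()):
                if self._loss_counts[t] == self.history_per_term:
                    # Shift out the oldest recorded loss for this timestep.
                    self._loss_history[t, :-1] = self._loss_history[t, 1:]
                    self._loss_history[t, -1] = loss
                else:
                    self._loss_history[t, self._loss_counts[t]] = loss
                    self._loss_counts[t] += 1


    def create_named_schedule_sampler(name, num_timesteps):
        if name == "uniform":
            return UniformSampler(num_timesteps)
        elif name == "loss-second-moment":
            return LossSecondMomentResampler(num_timesteps)
        else:
            raise NotImplementedError(f"unknown schedule sampler: {name}")

Under this reading, the trainer would call sample() to draw (timesteps, weights) each
step and, for the loss-second-moment variant, feed the resulting losses back via
update_with_losses().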
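
The patch likewise threads min_snr_gamma through to the trainer without showing how the
loss consumes it. Min-SNR weighting (Hang et al. 2023, "Efficient Diffusion Training via
Min-SNR Weighting Strategy") is conventionally applied as a per-timestep loss weight;
below is a minimal sketch under that assumption, for a diffusers-style scheduler exposing
alphas_cumprod. The helper names are illustrative, not from this repo.

    # Hedged sketch of Min-SNR loss weighting; how training.functional.train
    # actually uses min_snr_gamma is not visible in this patch.
    import torch


    def compute_snr(noise_scheduler, timesteps):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler.
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
        alpha_bar = alphas_cumprod[timesteps]
        return alpha_bar / (1 - alpha_bar)


    def min_snr_weights(noise_scheduler, timesteps, gamma, prediction_type="epsilon"):
        snr = compute_snr(noise_scheduler, timesteps)
        clipped = torch.clamp(snr, max=gamma)  # min(SNR(t), gamma)
        if prediction_type == "epsilon":
            return clipped / snr
        elif prediction_type == "v_prediction":
            return clipped / (snr + 1)
        raise ValueError(f"unsupported prediction type: {prediction_type}")

With these weights, the per-sample MSE loss is multiplied by min_snr_weights(...) before
reduction, down-weighting low-noise (high-SNR) timesteps; the paper reports this speeds
up convergence, with gamma = 5 as the recommended default.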