diff options
author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 (patch) | |
tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2 textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 17 |
1 file changed, 17 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index cc7c1ec..70fbae4 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -27,6 +27,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
27 | from training.functional import train, add_placeholder_tokens, get_models | 27 | from training.functional import train, add_placeholder_tokens, get_models |
28 | from training.strategy.lora import lora_strategy | 28 | from training.strategy.lora import lora_strategy |
29 | from training.optimization import get_scheduler | 29 | from training.optimization import get_scheduler |
30 | from training.sampler import create_named_schedule_sampler | ||
30 | from training.util import AverageMeter, save_args | 31 | from training.util import AverageMeter, save_args |
31 | 32 | ||
32 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 33 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
@@ -410,6 +411,19 @@ def parse_args(): | |||
410 | help="Minimum learning rate in the lr scheduler." | 411 | help="Minimum learning rate in the lr scheduler." |
411 | ) | 412 | ) |
412 | parser.add_argument( | 413 | parser.add_argument( |
414 | "--min_snr_gamma", | ||
415 | type=int, | ||
416 | default=5, | ||
417 | help="MinSNR gamma." | ||
418 | ) | ||
419 | parser.add_argument( | ||
420 | "--schedule_sampler", | ||
421 | type=str, | ||
422 | default="uniform", | ||
423 | choices=["uniform", "loss-second-moment"], | ||
424 | help="Noise schedule sampler." | ||
425 | ) | ||
426 | parser.add_argument( | ||
413 | "--optimizer", | 427 | "--optimizer", |
414 | type=str, | 428 | type=str, |
415 | default="adan", | 429 | default="adan", |
@@ -708,6 +722,7 @@ def main(): | |||
708 | args.emb_alpha, | 722 | args.emb_alpha, |
709 | args.emb_dropout | 723 | args.emb_dropout |
710 | ) | 724 | ) |
725 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | ||
711 | 726 | ||
712 | unet_config = LoraConfig( | 727 | unet_config = LoraConfig( |
713 | r=args.lora_r, | 728 | r=args.lora_r, |
@@ -923,6 +938,8 @@ def main(): | |||
923 | tokenizer=tokenizer, | 938 | tokenizer=tokenizer, |
924 | vae=vae, | 939 | vae=vae, |
925 | noise_scheduler=noise_scheduler, | 940 | noise_scheduler=noise_scheduler, |
941 | schedule_sampler=schedule_sampler, | ||
942 | min_snr_gamma=args.min_snr_gamma, | ||
926 | dtype=weight_dtype, | 943 | dtype=weight_dtype, |
927 | seed=args.seed, | 944 | seed=args.seed, |
928 | compile_unet=args.compile_unet, | 945 | compile_unet=args.compile_unet, |