diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
| commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 (patch) | |
| tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_lora.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2 textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip | |
Update
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index cc7c1ec..70fbae4 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -27,6 +27,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
| 27 | from training.functional import train, add_placeholder_tokens, get_models | 27 | from training.functional import train, add_placeholder_tokens, get_models |
| 28 | from training.strategy.lora import lora_strategy | 28 | from training.strategy.lora import lora_strategy |
| 29 | from training.optimization import get_scheduler | 29 | from training.optimization import get_scheduler |
| 30 | from training.sampler import create_named_schedule_sampler | ||
| 30 | from training.util import AverageMeter, save_args | 31 | from training.util import AverageMeter, save_args |
| 31 | 32 | ||
| 32 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 33 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
| @@ -410,6 +411,19 @@ def parse_args(): | |||
| 410 | help="Minimum learning rate in the lr scheduler." | 411 | help="Minimum learning rate in the lr scheduler." |
| 411 | ) | 412 | ) |
| 412 | parser.add_argument( | 413 | parser.add_argument( |
| 414 | "--min_snr_gamma", | ||
| 415 | type=int, | ||
| 416 | default=5, | ||
| 417 | help="MinSNR gamma." | ||
| 418 | ) | ||
| 419 | parser.add_argument( | ||
| 420 | "--schedule_sampler", | ||
| 421 | type=str, | ||
| 422 | default="uniform", | ||
| 423 | choices=["uniform", "loss-second-moment"], | ||
| 424 | help="Noise schedule sampler." | ||
| 425 | ) | ||
| 426 | parser.add_argument( | ||
| 413 | "--optimizer", | 427 | "--optimizer", |
| 414 | type=str, | 428 | type=str, |
| 415 | default="adan", | 429 | default="adan", |
| @@ -708,6 +722,7 @@ def main(): | |||
| 708 | args.emb_alpha, | 722 | args.emb_alpha, |
| 709 | args.emb_dropout | 723 | args.emb_dropout |
| 710 | ) | 724 | ) |
| 725 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | ||
| 711 | 726 | ||
| 712 | unet_config = LoraConfig( | 727 | unet_config = LoraConfig( |
| 713 | r=args.lora_r, | 728 | r=args.lora_r, |
| @@ -923,6 +938,8 @@ def main(): | |||
| 923 | tokenizer=tokenizer, | 938 | tokenizer=tokenizer, |
| 924 | vae=vae, | 939 | vae=vae, |
| 925 | noise_scheduler=noise_scheduler, | 940 | noise_scheduler=noise_scheduler, |
| 941 | schedule_sampler=schedule_sampler, | ||
| 942 | min_snr_gamma=args.min_snr_gamma, | ||
| 926 | dtype=weight_dtype, | 943 | dtype=weight_dtype, |
| 927 | seed=args.seed, | 944 | seed=args.seed, |
| 928 | compile_unet=args.compile_unet, | 945 | compile_unet=args.compile_unet, |
