summary | refs | log | tree | commit | diff | stats
path: root/train_lora.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2023-05-06 16:25:36 +0200
committer Volpeon <git@volpeon.ink>  2023-05-06 16:25:36 +0200
commit   7b04d813739c0b5595295dffdc86cc41108db2d3 (patch)
tree     8958b612f5d3d665866770ad553e1004aa4b6fb8 /train_lora.py
parent   Update (diff)
download textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz
         textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2
         textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  17
1 file changed, 17 insertions(+), 0 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index cc7c1ec..70fbae4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -27,6 +27,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
+from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args

 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
@@ -410,6 +411,19 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--min_snr_gamma",
+        type=int,
+        default=5,
+        help="MinSNR gamma."
+    )
+    parser.add_argument(
+        "--schedule_sampler",
+        type=str,
+        default="uniform",
+        choices=["uniform", "loss-second-moment"],
+        help="Noise schedule sampler."
+    )
+    parser.add_argument(
         "--optimizer",
         type=str,
         default="adan",
@@ -708,6 +722,7 @@ def main():
         args.emb_alpha,
         args.emb_dropout
     )
+    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)

     unet_config = LoraConfig(
         r=args.lora_r,
@@ -923,6 +938,8 @@ def main():
         tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
+        schedule_sampler=schedule_sampler,
+        min_snr_gamma=args.min_snr_gamma,
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,