From 3924055ed24da9b6995303cd36282eb558ba0bf0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 14:45:37 +0200
Subject: Fix

---
 train_lora.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index d5dde02..5c78664 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -506,6 +506,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
+    )
     parser.add_argument(
         "--emb_dropout",
         type=float,
@@ -660,7 +666,10 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.emb_alpha,
+        args.emb_dropout
+    )
 
     unet_config = LoraConfig(
         r=args.lora_r,
-- 
cgit v1.2.3-54-g00ecf
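
Context for the fix: get_models() evidently gained an emb_alpha parameter, so
train_lora.py now defines a matching --emb_alpha flag and passes it through at
the call site, resolving the signature mismatch. The repository's get_models
implementation is not shown in this patch; below is a minimal, illustrative
sketch of how an embedding alpha of this kind is commonly applied, scaling
trainable embedding offsets on top of a frozen base table. The ScaledEmbedding
class, its signature, and the offset formulation are assumptions for
illustration only, not the project's actual code.

    import torch
    import torch.nn as nn


    class ScaledEmbedding(nn.Module):
        """Frozen base embedding plus alpha-scaled trainable offsets (sketch)."""

        def __init__(self, base: nn.Embedding, alpha: float = 1.0, dropout: float = 0.0):
            super().__init__()
            self.base = base                      # frozen pretrained embedding table
            self.alpha = alpha                    # strength of the learned offsets (--emb_alpha)
            self.dropout = nn.Dropout(dropout)    # regularization on offsets (--emb_dropout)
            # Trainable per-token offsets, initialized to zero so alpha=1.0
            # starts out identical to the base model.
            self.delta = nn.Parameter(torch.zeros_like(base.weight))

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Base lookup plus alpha-scaled, dropout-regularized learned offsets.
            return self.base(input_ids) + self.alpha * self.dropout(self.delta[input_ids])


    # Usage sketch, mirroring how the new arguments would be threaded through:
    base = nn.Embedding(49408, 768)
    emb = ScaledEmbedding(base, alpha=1.0, dropout=0.1)
    hidden = emb(torch.tensor([[0, 1, 2]]))

Under this reading, alpha=0.0 disables the learned embeddings entirely and
alpha=1.0 applies them at full strength, which is consistent with the flag's
default of 1.0 in the patch.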