From 3924055ed24da9b6995303cd36282eb558ba0bf0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 14:45:37 +0200 Subject: Fix --- train_ti.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 7f5fb49..45e730a 100644 --- a/train_ti.py +++ b/train_ti.py @@ -484,19 +484,13 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--lora_r", - type=int, - default=8, - help="Lora rank, only used if use_lora is True" - ) - parser.add_argument( - "--lora_alpha", - type=int, - default=32, - help="Lora alpha, only used if use_lora is True" + "--emb_alpha", + type=float, + default=1.0, + help="Embedding alpha" ) parser.add_argument( - "--lora_dropout", + "--emb_dropout", type=float, default=0, help="Embedding dropout probability.", @@ -669,9 +663,8 @@ def main(): tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( args.pretrained_model_name_or_path, - args.lora_r, - args.lora_alpha, - args.lora_dropout + args.emb_alpha, + args.emb_dropout ) tokenizer.set_use_vector_shuffle(args.vector_shuffle) -- cgit v1.2.3-54-g00ecf