From fa4b797914bc9997039562afefd64620bf235d60 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 27 Apr 2023 08:39:33 +0200
Subject: Fix

---
 train_lora.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 9cf17c7..d5aa78d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -511,6 +511,11 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--run_pti",
+        action="store_true",
+        help="Whether to run PTI."
+    )
     parser.add_argument(
         "--emb_alpha",
         type=float,
@@ -714,6 +719,7 @@ def main():
         embeddings.persist()
         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
 
+    placeholder_tokens = []
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
@@ -722,16 +728,18 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+
+        placeholder_tokens = added_tokens
+        placeholder_token_ids = added_ids
+
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
         if args.train_dir_embeddings:
-            args.placeholder_tokens = added_tokens
-            placeholder_token_ids = added_ids
             print("Training embeddings from embeddings dir")
         else:
             embeddings.persist()
 
-    if not args.train_dir_embeddings:
+    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -740,8 +748,11 @@ def main():
             num_vectors=args.num_vectors,
             initializer_noise=args.initializer_noise,
         )
+
+        placeholder_tokens = args.placeholder_tokens
+
         stats = list(zip(
-            args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+            placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
         ))
         print(f"Training embeddings: {stats}")
 
@@ -878,7 +889,7 @@ def main():
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
-       placeholder_tokens=args.placeholder_tokens,
+       placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        use_emb_decay=args.use_emb_decay,
        emb_decay_target=args.emb_decay_target,
@@ -924,8 +935,8 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    if len(args.placeholder_tokens) != 0:
-        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+    if args.run_pti and len(placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
             valid_set_size=0,
@@ -1094,7 +1105,7 @@ def main():
            group_labels.append("unet")
 
        if training_iter < args.train_text_encoder_cycles:
-           # if len(args.placeholder_tokens) != 0:
+           # if len(placeholder_tokens) != 0:
            #     params_to_optimize.append({
            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
            #         "lr": learning_rate_emb,
-- 
cgit v1.2.3-70-g09d2