From c40170386fd055f715db90886f0ac0da5c575bd9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 13:19:05 +0200 Subject: Fix TI --- train_ti.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 9ae8d1b..e4fd464 100644 --- a/train_ti.py +++ b/train_ti.py @@ -588,14 +588,6 @@ def main(): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - if args.embeddings_dir is not None: - embeddings_dir = Path(args.embeddings_dir) - if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - raise ValueError("--embeddings_dir must point to an existing directory") - - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - if len(args.alias_tokens) != 0: alias_placeholder_tokens = args.alias_tokens[::2] alias_initializer_tokens = args.alias_tokens[1::2] @@ -609,6 +601,14 @@ def main(): embeddings.persist() print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") + if args.embeddings_dir is not None: + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * -- cgit v1.2.3-54-g00ecf