From eb6a92abda5893c975437026cdaf0ce0bfefe2a4 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Apr 2023 10:34:12 +0200
Subject: Update

---
 train_ti.py | 68 +++++++++++++++++++++++++++++++++++++------------------------
 1 file changed, 41 insertions(+), 27 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index ebac302..eb08bda 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -151,6 +151,11 @@ def parse_args():
         default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
+    parser.add_argument(
+        "--train_dir_embeddings",
+        action="store_true",
+        help="Train embeddings loaded from embeddings directory.",
+    )
     parser.add_argument(
         "--collection",
         type=str,
@@ -403,6 +408,12 @@ def parse_args():
         action='store_true',
         help="If checkpoints are saved on maximum accuracy",
     )
+    parser.add_argument(
+        "--sample_num",
+        type=int,
+        default=None,
+        help="How often to save a checkpoint and sample image (in number of samples)",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -669,9 +680,14 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+        if args.train_dir_embeddings:
+            args.placeholder_tokens = added_tokens
+            print("Training embeddings from embeddings dir")
+        else:
+            embeddings.persist()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -852,28 +868,8 @@ def main():
             args.num_train_steps / len(datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
-
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e3,
-        train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
-    )
+    if args.sample_num is not None:
+        sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
     training_iter = 0
 
@@ -888,6 +884,28 @@ def main():
         print(f"------------ TI cycle {training_iter} ------------")
         print("")
 
+        optimizer = create_optimizer(
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            lr=args.learning_rate,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            end_lr=1e3,
+            train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+            mid_point=args.lr_mid_point,
+        )
+
         project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
         sample_output_dir = output_dir / project / "samples"
         checkpoint_output_dir = output_dir / project / "checkpoints"
@@ -908,10 +926,6 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )
 
-        response = input("Run another cycle? [y/n] ")
-        continue_training = response.lower().strip() != "n"
-        training_iter += 1
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
--
cgit v1.2.3-54-g00ecf
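
Not part of the commit above: a minimal, hypothetical sketch of the sampling-frequency behaviour this patch introduces, written only from the hunks shown here. The helper derive_sample_frequency and the standalone argparse wiring are illustrative assumptions; in the repository these values feed into train_ti.py's own training loop.

# Illustrative sketch (not repository code): mirrors the logic added in this
# commit, where --sample_num, if given, overrides the step-derived sample_frequency.
import argparse
import math
from typing import Optional


def derive_sample_frequency(num_train_epochs: int, sample_frequency: int,
                            num_train_steps: int, sample_num: Optional[int]) -> int:
    # Existing behaviour: rescale the step-based sample_frequency to whole epochs.
    frequency = math.ceil(num_train_epochs * (sample_frequency / num_train_steps))
    # New behaviour from this patch: spread sample_num sample/checkpoint events
    # evenly over the training epochs when --sample_num is set.
    if sample_num is not None:
        frequency = math.ceil(num_train_epochs / sample_num)
    return frequency


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sample_num", type=int, default=None)
    parser.add_argument("--sample_frequency", type=int, default=10)
    args = parser.parse_args()
    # Example: 100 epochs with --sample_num 4 -> sample every 25 epochs.
    print(derive_sample_frequency(100, args.sample_frequency, 1000, args.sample_num))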