From ba9fd1a10746d85d2502c8a79ac49db63d346b04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Apr 2023 11:29:31 +0200
Subject: Update

---
 train_ti.py | 66 +++++++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 42 insertions(+), 24 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -457,6 +457,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
     parser.add_argument(
         "--use_emb_decay",
         action="store_true",
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -777,7 +781,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
@@ -793,11 +796,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@ def main():
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
 
-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")
 
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
@@ -868,20 +870,36 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
-- 
cgit v1.2.3-54-g00ecf