From 37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 11:31:21 +0200
Subject: Run PTI only if placeholder tokens arg isn't empty

---
 train_lora.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------------------
 1 file changed, 55 insertions(+), 54 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------

-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
-
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
-
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
-
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )

-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )

-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )

-    plot_metrics(metrics, output_dir/"lr.png")
+        plot_metrics(metrics, pti_output_dir / "lr.png")

     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )

-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")


 if __name__ == "__main__":
-- 
cgit v1.2.3-54-g00ecf
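Net effect of the patch: the whole PTI phase (its datamodule, optimizer, LR scheduler and trainer run) is now wrapped in a guard on args.placeholder_tokens, so a run started without placeholder tokens skips straight to the LoRA phase, and each phase writes its lr.png metrics plot into its own output directory instead of both writing to output_dir/"lr.png". Below is a minimal, self-contained sketch of that guard behaviour; it is a hypothetical standalone script, not code from train_lora.py, and only the --placeholder_tokens argument name is taken from the patch.

    import argparse


    def main():
        parser = argparse.ArgumentParser()
        # nargs="*" with an empty default models an optional, possibly empty token
        # list, which is the case the new `if len(args.placeholder_tokens) != 0:`
        # guard checks for.
        parser.add_argument("--placeholder_tokens", type=str, nargs="*", default=[])
        args = parser.parse_args()

        if len(args.placeholder_tokens) != 0:
            # In train_lora.py this branch builds the PTI datamodule, optimizer and
            # scheduler and runs the trainer; here it only reports the decision.
            print("PTI phase runs for tokens:", args.placeholder_tokens)
        else:
            print("No placeholder tokens given; PTI is skipped, LoRA training runs next.")


    if __name__ == "__main__":
        main()

Run with no arguments, the sketch prints the skip message; run with --placeholder_tokens <token>, it takes the PTI branch, mirroring the two paths through the patched main().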