From 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 07:58:14 +0200
Subject: Update

---
 train_lora.py | 75 ++++++++++++++++++++++++++++++++---------------------------
 1 file changed, 41 insertions(+), 34 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

@@ -918,7 +920,7 @@ def main():
         train_epochs=num_pti_epochs,
     )

-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         pti_mode=True,
         project="pti",
@@ -929,12 +931,13 @@
         num_train_epochs=num_pti_epochs,
         gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
+        group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
     )

-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    # embeddings.persist()

     # LORA
     # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))

-    lora_optimizer = create_optimizer(
-        [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_unet,
-            },
-            {
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_text,
-            },
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_text,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    params_to_optimize = []
+    group_labels = []
+    if len(args.placeholder_tokens) != 0:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            "lr": args.learning_rate_text,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
+    params_to_optimize += [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_text,
+        },
+    ]
+    group_labels += ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)

     lora_lr_scheduler = create_lr_scheduler(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@
         train_epochs=num_train_epochs,
     )

-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         project="lora",
         train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
+        group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
     )

-    plot_metrics(metrics, lora_output_dir / "lr.png")
-

 if __name__ == "__main__":
     main()
--
cgit v1.2.3-70-g09d2
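Editor's note, not part of the patch: a minimal sketch of the pattern the commit introduces, in which optimizer parameter groups are built in a list alongside a parallel group_labels list so each group's learning rate can be reported under a readable name. The stand-in modules, the learning rates, the train_embeddings flag, and the use of torch.optim.AdamW below are assumptions for illustration; the repository's own create_optimizer and trainer wiring are not reproduced here.

# Sketch only: mirrors the params_to_optimize / group_labels pairing from the
# diff above, using small stand-in modules instead of the real models.
import torch

emb = torch.nn.Embedding(4, 8)   # stand-in for the token override embeddings
unet = torch.nn.Linear(8, 8)     # stand-in for the trainable UNet parameters
text = torch.nn.Linear(8, 8)     # stand-in for the text encoder parameters

params_to_optimize = []
group_labels = []

train_embeddings = True          # corresponds to len(args.placeholder_tokens) != 0
if train_embeddings:
    params_to_optimize.append(
        {"params": emb.parameters(), "lr": 1e-4, "weight_decay": 0}
    )
    group_labels.append("emb")

params_to_optimize += [
    {"params": (p for p in unet.parameters() if p.requires_grad), "lr": 1e-4},
    {"params": (p for p in text.parameters() if p.requires_grad), "lr": 5e-5},
]
group_labels += ["unet", "text"]

optimizer = torch.optim.AdamW(params_to_optimize)  # stand-in for create_optimizer

# Both lists are built in the same order, so labels and param groups stay
# aligned; passing group_labels to the trainer relies on exactly this pairing.
for label, group in zip(group_labels, optimizer.param_groups):
    print(f"lr/{label}: {group['lr']}")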