From 050dcbde0483c277cd632e05d8a0f73c87332785 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 20 Apr 2023 12:29:58 +0200
Subject: Update

---
 train_lora.py | 145 ++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 116 insertions(+), 29 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index d0313fe..0ae8b31 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -302,6 +302,11 @@ def parse_args():
         default="none",
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
+    parser.add_argument(
+        "--train_text_encoder_cycles",
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -919,6 +924,78 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    if len(args.placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.train_batch_size,
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_train_epochs = args.num_train_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_train_epochs is None:
+            num_train_epochs = math.ceil(
+                args.num_train_steps / len(pti_datamodule.train_dataset)
+            ) * args.gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(
+            len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
+        if args.sample_num is not None:
+            pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
+
+        pti_project = "pti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(pti_project)
+
+        pti_sample_output_dir = output_dir / pti_project / "samples"
+
+        print("")
+        print(f"============ PTI ============")
+        print("")
+
+        pti_optimizer = create_optimizer([{
+            "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        }])
+
+        pti_lr_scheduler = create_lr_scheduler(
+            "constant_with_warmup",
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_train_epochs,
+            warmup_epochs=math.ceil(0.1 * num_train_epochs),
+        )
+
+        pti_checkpoint_output_dir = output_dir / pti_project / "model"
+
+        trainer(
+            strategy=lora_strategy,
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            cycle=1,
+            pti_mode=True,
+            # --
+            group_labels=["emb"],
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+        )
+
+        embeddings.persist()
+
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -941,16 +1018,6 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
-    group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        group_labels.append("emb")
-    group_labels += ["unet", "text"]
-
-    training_iter = 0
-    learning_rate_emb = args.learning_rate_emb
-    learning_rate_unet = args.learning_rate_unet
-    learning_rate_text = args.learning_rate_text
-
     lora_project = "lora"
 
     if accelerator.is_main_process:
@@ -958,7 +1025,11 @@ def main():
 
     lora_sample_output_dir = output_dir / lora_project / "samples"
 
+    training_iter = 0
     auto_cycles = list(args.auto_cycles)
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text
     lr_scheduler = args.lr_scheduler
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
@@ -969,6 +1040,18 @@ def main():
         else:
             response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
 
+        if response.lower().strip() == "o":
+            if args.learning_rate_emb is not None:
+                learning_rate_emb = args.learning_rate_emb * 2
+            if args.learning_rate_unet is not None:
+                learning_rate_unet = args.learning_rate_unet * 2
+            if args.learning_rate_text is not None:
+                learning_rate_text = args.learning_rate_text * 2
+        else:
+            learning_rate_emb = args.learning_rate_emb
+            learning_rate_unet = args.learning_rate_unet
+            learning_rate_text = args.learning_rate_text
+
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
             lr_warmup_epochs = args.lr_warmup_epochs
@@ -986,28 +1069,32 @@ def main():
             break
 
         print("")
-        print(f"============ LoRA cycle {training_iter + 1} ============")
+        print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
        print("")
 
         params_to_optimize = []
+        group_labels = []
+
+        params_to_optimize.append({
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        })
+        group_labels.append("unet")
+
+        if training_iter < args.train_text_encoder_cycles:
+            # if len(args.placeholder_tokens) != 0:
+            #     params_to_optimize.append({
+            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            #         "lr": learning_rate_emb,
+            #         "weight_decay": 0,
+            #     })
+            #     group_labels.append("emb")
 
-        if len(args.placeholder_tokens) != 0:
             params_to_optimize.append({
-                "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-                "lr": learning_rate_emb,
-                "weight_decay": 0,
-            })
-            group_labels.append("emb")
-        params_to_optimize += [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_unet,
-            },
-            {
                 "params": (
                     param
                     for param in itertools.chain(
@@ -1017,8 +1104,8 @@ def main():
                     if param.requires_grad
                 ),
                 "lr": learning_rate_text,
-            },
-        ]
+            })
+            group_labels.append("text")
 
         lora_optimizer = create_optimizer(params_to_optimize)
 
--
cgit v1.2.3-54-g00ecf
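
Editorial note (not part of the patch): the LoRA hunks above rebuild the optimizer parameter groups inside the cycle loop, always training the UNet group and adding the text-encoder group only while the cycle counter is below the new --train_text_encoder_cycles limit. The sketch below restates that gating logic in isolation. It is a minimal illustration under assumptions: build_param_groups is a hypothetical helper, and it uses text_encoder.parameters() instead of the specific text-model submodules the script chains with itertools.chain.

    from typing import Dict, List, Tuple

    import torch


    def build_param_groups(
        unet: torch.nn.Module,
        text_encoder: torch.nn.Module,
        training_iter: int,
        train_text_encoder_cycles: int,
        learning_rate_unet: float,
        learning_rate_text: float,
    ) -> Tuple[List[Dict], List[str]]:
        params_to_optimize: List[Dict] = []
        group_labels: List[str] = []

        # The UNet LoRA parameters are trained in every cycle.
        params_to_optimize.append({
            "params": [p for p in unet.parameters() if p.requires_grad],
            "lr": learning_rate_unet,
        })
        group_labels.append("unet")

        # The text-encoder group is only included while the cycle counter is
        # below the --train_text_encoder_cycles limit (default 999999, i.e.
        # effectively always).
        if training_iter < train_text_encoder_cycles:
            params_to_optimize.append({
                "params": [p for p in text_encoder.parameters() if p.requires_grad],
                "lr": learning_rate_text,
            })
            group_labels.append("text")

        return params_to_optimize, group_labels

The resulting list can be passed to any torch optimizer that accepts parameter groups, for example torch.optim.AdamW(params_to_optimize); group_labels mirrors the labels the script hands to its trainer for logging.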