From 4ca5d3bf8710c479dab994a7243b2dd9232bab3d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Apr 2023 17:42:23 +0200 Subject: Fixed Lora PTI --- train_lora.py | 37 ++++++++++++++++++++----------------- training/strategy/lora.py | 35 +++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/train_lora.py b/train_lora.py index 476efcf..5b0a292 100644 --- a/train_lora.py +++ b/train_lora.py @@ -21,7 +21,6 @@ from data.csv import VlpnDataModule, keyword_filter from training.functional import train, add_placeholder_tokens, get_models from training.lr import plot_metrics from training.strategy.lora import lora_strategy -from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler from training.util import save_args @@ -829,6 +828,12 @@ def main(): sample_num_batches=args.sample_batches, sample_num_steps=args.sample_steps, sample_image_size=args.sample_image_size, + placeholder_tokens=args.placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay=args.emb_decay, + max_grad_norm=args.max_grad_norm, ) create_datamodule = partial( @@ -907,7 +912,8 @@ def main(): ) metrics = trainer( - strategy=textual_inversion_strategy, + strategy=lora_strategy, + pti_mode=True, project="pti", train_dataloader=pti_datamodule.train_dataloader, val_dataloader=pti_datamodule.val_dataloader, @@ -919,11 +925,6 @@ def main(): sample_output_dir=pti_sample_output_dir, checkpoint_output_dir=pti_checkpoint_output_dir, sample_frequency=math.inf, - placeholder_tokens=args.placeholder_tokens, - placeholder_token_ids=placeholder_token_ids, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay=args.emb_decay, ) plot_metrics(metrics, pti_output_dir / "lr.png") @@ -952,13 +953,21 @@ def main(): lora_optimizer = create_optimizer( [ { - "params": unet.parameters(), + "params": ( + param + for param in unet.parameters() + if param.requires_grad + ), "lr": args.learning_rate_unet, }, { - "params": itertools.chain( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), + "params": ( + param + for param in itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + ) + if param.requires_grad ), "lr": args.learning_rate_text, }, @@ -990,12 +999,6 @@ def main(): sample_output_dir=lora_sample_output_dir, checkpoint_output_dir=lora_checkpoint_output_dir, sample_frequency=lora_sample_frequency, - placeholder_tokens=args.placeholder_tokens, - placeholder_token_ids=placeholder_token_ids, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay=args.emb_decay, - max_grad_norm=args.max_grad_norm, ) plot_metrics(metrics, lora_output_dir / "lr.png") diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 80ffa9c..912ff26 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -34,6 +34,7 @@ def lora_strategy_callbacks( seed: int, placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], + pti_mode: bool = False, use_emb_decay: bool = False, emb_decay_target: float = 0.4, emb_decay: float = 1e-2, @@ -79,10 +80,11 @@ def lora_strategy_callbacks( yield def on_before_optimize(lr: float, epoch: int): - accelerator.clip_grad_norm_( - itertools.chain(unet.parameters(), text_encoder.parameters()), - max_grad_norm - ) + if not pti_mode: + accelerator.clip_grad_norm_( + itertools.chain(unet.parameters(), text_encoder.parameters()), + max_grad_norm + ) if use_emb_decay: params = [ @@ -117,20 +119,21 @@ def lora_strategy_callbacks( checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" ) - lora_config = {} - state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) - lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) + if not pti_mode: + lora_config = {} + state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) + lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) - text_encoder_state_dict = get_peft_model_state_dict( - text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) - ) - text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} - state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) + text_encoder_state_dict = get_peft_model_state_dict( + text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) + ) + text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} + state_dict.update(text_encoder_state_dict) + lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) - save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") - with open(checkpoint_output_dir / "lora_config.json", "w") as f: - json.dump(lora_config, f) + save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") + with open(checkpoint_output_dir / "lora_config.json", "w") as f: + json.dump(lora_config, f) del unet_ del text_encoder_ -- cgit v1.2.3-70-g09d2