From 4ca5d3bf8710c479dab994a7243b2dd9232bab3d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 17:42:23 +0200
Subject: Fixed Lora PTI

---
 training/strategy/lora.py | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

(limited to 'training')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 80ffa9c..912ff26 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -34,6 +34,7 @@ def lora_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    pti_mode: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -79,10 +80,11 @@ def lora_strategy_callbacks(
         yield

     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(
-            itertools.chain(unet.parameters(), text_encoder.parameters()),
-            max_grad_norm
-        )
+        if not pti_mode:
+            accelerator.clip_grad_norm_(
+                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                max_grad_norm
+            )

         if use_emb_decay:
             params = [
@@ -117,20 +119,21 @@ def lora_strategy_callbacks(
                 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
             )

-        lora_config = {}
-        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
-        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
+        if not pti_mode:
+            lora_config = {}
+            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)

-        text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
-        )
-        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
-        state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            text_encoder_state_dict = get_peft_model_state_dict(
+                text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
+            )
+            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            state_dict.update(text_encoder_state_dict)
+            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)

-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
-        with open(checkpoint_output_dir / "lora_config.json", "w") as f:
-            json.dump(lora_config, f)
+            save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+            with open(checkpoint_output_dir / "lora_config.json", "w") as f:
+                json.dump(lora_config, f)

         del unet_
         del text_encoder_
--
cgit v1.2.3-70-g09d2
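
For reference, a minimal sketch of how a caller might use the new pti_mode flag.
The full signature of lora_strategy_callbacks is not visible in this patch, so
every keyword argument below other than seed, placeholder_tokens,
placeholder_token_ids, pti_mode, and use_emb_decay is an assumption about the
surrounding trainer code, not something this patch defines:

    # Hypothetical caller (accelerator/unet/text_encoder/args are assumed names).
    callbacks = lora_strategy_callbacks(
        accelerator=accelerator,                    # assumed: used by the callbacks
        unet=unet,                                  # assumed
        text_encoder=text_encoder,                  # assumed
        seed=args.seed,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        pti_mode=True,   # PTI phase: skip grad clipping over UNet/text encoder
                         # and skip the LoRA safetensors/config checkpoint
        use_emb_decay=True,
    )

With pti_mode=True, on_before_optimize no longer clips gradients over the UNet
and text encoder parameters, and on_checkpoint writes only the per-token
embedding .bin files; the default of False preserves the previous behavior for
regular LoRA training.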