From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 training/strategy/lora.py | 41 +++++++++++++++++++++++++++++------------
 1 file changed, 29 insertions(+), 12 deletions(-)

(limited to 'training/strategy/lora.py')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
         if lambda_ != 0:
             norm = w[:, :].norm(dim=-1, keepdim=True)
-            w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            w[:].add_(
+                (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+            )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
            )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
             if len(placeholder_tokens) != 0:
                 ti_state_dict = {
                     f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                    for (token, ids)
-                    in zip(placeholder_tokens, placeholder_token_ids)
+                    for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
                 }
                 state_dict.update(ti_state_dict)
 
-            save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+            save_file(
+                state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+            )
             with open(checkpoint_output_dir / "lora_config.json", "w") as f:
                 json.dump(lora_config, f)
 
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
--
cgit v1.2.3-54-g00ecf