From 185c6b520d2136c87b122b89b52a0cc013240c6e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 24 Mar 2023 11:50:22 +0100
Subject: Fixed Lora training perf issue

---
 training/strategy/lora.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1e32114..8905171 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from peft import get_peft_model_state_dict
+from safetensors.torch import save_file
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
@@ -78,6 +79,9 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
+        if postfix != "end":
+            return
+
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
@@ -94,26 +98,23 @@ def lora_strategy_callbacks(
             state_dict.update(text_encoder_state_dict)
             lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
-        accelerator.print(state_dict)
-        accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
+        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
 
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
         del unet_
         del text_encoder_
 
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     @torch.no_grad()
     def on_sample(step):
-        vae_dtype = vae.dtype
-        vae.to(dtype=text_encoder.dtype)
-
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-        vae.to(dtype=vae_dtype)
-
         del unet_
         del text_encoder_
-- 
cgit v1.2.3-70-g09d2
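
A minimal sketch (not part of this commit) of reading back a checkpoint written by the new on_checkpoint: the f"{step}_{postfix}.safetensors" filename pattern and the lora_config.json sidecar come from the diff above, while the "checkpoints" directory and the step number 10000 are made-up placeholders.

    import json

    from safetensors.torch import load_file

    # Load the flat tensor dict that save_file() wrote; the keys combine the
    # UNet and text-encoder LoRA weights exactly as on_checkpoint merged them.
    state_dict = load_file("checkpoints/10000_end.safetensors")

    # The PEFT configuration written next to the checkpoint.
    with open("checkpoints/lora_config.json") as f:
        lora_config = json.load(f)

    print(f"Loaded {len(state_dict)} LoRA tensors")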