summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/strategy/lora.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1e32114..8905171 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
12from transformers import CLIPTextModel 12from transformers import CLIPTextModel
13from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 13from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
14from peft import get_peft_model_state_dict 14from peft import get_peft_model_state_dict
15from safetensors.torch import save_file
15 16
16from models.clip.tokenizer import MultiCLIPTokenizer 17from models.clip.tokenizer import MultiCLIPTokenizer
17from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
@@ -78,6 +79,9 @@ def lora_strategy_callbacks(
78 79
79 @torch.no_grad() 80 @torch.no_grad()
80 def on_checkpoint(step, postfix): 81 def on_checkpoint(step, postfix):
82 if postfix != "end":
83 return
84
81 print(f"Saving checkpoint for step {step}...") 85 print(f"Saving checkpoint for step {step}...")
82 86
83 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 87 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
@@ -94,26 +98,23 @@ def lora_strategy_callbacks(
94 state_dict.update(text_encoder_state_dict) 98 state_dict.update(text_encoder_state_dict)
95 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) 99 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
96 100
97 accelerator.print(state_dict) 101 save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
98 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
99 with open(checkpoint_output_dir / "lora_config.json", "w") as f: 102 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
100 json.dump(lora_config, f) 103 json.dump(lora_config, f)
101 104
102 del unet_ 105 del unet_
103 del text_encoder_ 106 del text_encoder_
104 107
108 if torch.cuda.is_available():
109 torch.cuda.empty_cache()
110
105 @torch.no_grad() 111 @torch.no_grad()
106 def on_sample(step): 112 def on_sample(step):
107 vae_dtype = vae.dtype
108 vae.to(dtype=text_encoder.dtype)
109
110 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 113 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
111 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 114 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
112 115
113 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 116 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
114 117
115 vae.to(dtype=vae_dtype)
116
117 del unet_ 118 del unet_
118 del text_encoder_ 119 del text_encoder_
119 120