summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py3
-rw-r--r--training/strategy/lora.py15
2 files changed, 9 insertions, 9 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dbd262f..ea2a656 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
375 375
376 def decode_latents(self, latents): 376 def decode_latents(self, latents):
377 latents = 1 / self.vae.config.scaling_factor * latents 377 latents = 1 / self.vae.config.scaling_factor * latents
378 # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample 378 image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
379 image = self.vae.decode(latents).sample
380 image = (image / 2 + 0.5).clamp(0, 1) 379 image = (image / 2 + 0.5).clamp(0, 1)
381 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 380 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
382 image = image.cpu().permute(0, 2, 3, 1).float().numpy() 381 image = image.cpu().permute(0, 2, 3, 1).float().numpy()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1e32114..8905171 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
12from transformers import CLIPTextModel 12from transformers import CLIPTextModel
13from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 13from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
14from peft import get_peft_model_state_dict 14from peft import get_peft_model_state_dict
15from safetensors.torch import save_file
15 16
16from models.clip.tokenizer import MultiCLIPTokenizer 17from models.clip.tokenizer import MultiCLIPTokenizer
17from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
@@ -78,6 +79,9 @@ def lora_strategy_callbacks(
78 79
79 @torch.no_grad() 80 @torch.no_grad()
80 def on_checkpoint(step, postfix): 81 def on_checkpoint(step, postfix):
82 if postfix != "end":
83 return
84
81 print(f"Saving checkpoint for step {step}...") 85 print(f"Saving checkpoint for step {step}...")
82 86
83 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 87 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
@@ -94,26 +98,23 @@ def lora_strategy_callbacks(
94 state_dict.update(text_encoder_state_dict) 98 state_dict.update(text_encoder_state_dict)
95 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) 99 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
96 100
97 accelerator.print(state_dict) 101 save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
98 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
99 with open(checkpoint_output_dir / "lora_config.json", "w") as f: 102 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
100 json.dump(lora_config, f) 103 json.dump(lora_config, f)
101 104
102 del unet_ 105 del unet_
103 del text_encoder_ 106 del text_encoder_
104 107
108 if torch.cuda.is_available():
109 torch.cuda.empty_cache()
110
105 @torch.no_grad() 111 @torch.no_grad()
106 def on_sample(step): 112 def on_sample(step):
107 vae_dtype = vae.dtype
108 vae.to(dtype=text_encoder.dtype)
109
110 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 113 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
111 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 114 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
112 115
113 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 116 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
114 117
115 vae.to(dtype=vae_dtype)
116
117 del unet_ 118 del unet_
118 del text_encoder_ 119 del text_encoder_
119 120