summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-24 11:50:22 +0100
committerVolpeon <git@volpeon.ink>2023-03-24 11:50:22 +0100
commit185c6b520d2136c87b122b89b52a0cc013240c6e (patch)
tree0f27f4407d51c11cd8239f1068eb1bd0986ac45c /pipelines/stable_diffusion
parentLora fix: Save config JSON, too (diff)
downloadtextual-inversion-diff-185c6b520d2136c87b122b89b52a0cc013240c6e.tar.gz
textual-inversion-diff-185c6b520d2136c87b122b89b52a0cc013240c6e.tar.bz2
textual-inversion-diff-185c6b520d2136c87b122b89b52a0cc013240c6e.zip
Fixed Lora training perf issue
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dbd262f..ea2a656 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
375 375
376 def decode_latents(self, latents): 376 def decode_latents(self, latents):
377 latents = 1 / self.vae.config.scaling_factor * latents 377 latents = 1 / self.vae.config.scaling_factor * latents
378 # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample 378 image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
379 image = self.vae.decode(latents).sample
380 image = (image / 2 + 0.5).clamp(0, 1) 379 image = (image / 2 + 0.5).clamp(0, 1)
381 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 380 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
382 image = image.cpu().permute(0, 2, 3, 1).float().numpy() 381 image = image.cpu().permute(0, 2, 3, 1).float().numpy()