summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
committerVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
commit64c594869135354a38353551bd58a93e15bd5b85 (patch)
tree2bcc085a396824f78e58c90b1f6e9553c7f5c8c1 /pipelines
parentFix img2img (diff)
downloadtextual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.gz
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.bz2
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.zip
Small performance improvements
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b4c85e9..8fbe5f9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
223 # Unlike in other pipelines, latents need to be generated in the target device 223 # Unlike in other pipelines, latents need to be generated in the target device
224 # for 1-to-1 results reproducibility with the CompVis implementation. 224 # for 1-to-1 results reproducibility with the CompVis implementation.
225 # However this currently doesn't work in `mps`. 225 # However this currently doesn't work in `mps`.
226 latents_device = "cpu" if self.device.type == "mps" else self.device 226 latents_dtype = text_embeddings.dtype
227 latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) 227 latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
228 if latents is None: 228 if latents is None:
229 latents = torch.randn( 229 if self.device.type == "mps":
230 latents_shape, 230 # randn does not exist on mps
231 generator=generator, 231 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
232 device=latents_device, 232 self.device
233 dtype=text_embeddings.dtype, 233 )
234 ) 234 else:
235 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
235 elif isinstance(latents, PIL.Image.Image): 236 elif isinstance(latents, PIL.Image.Image):
236 latents = preprocess(latents, width, height) 237 latents = preprocess(latents, width, height)
237 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist 238 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
@@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
259 else: 260 else:
260 if latents.shape != latents_shape: 261 if latents.shape != latents_shape:
261 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 262 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
262 latents = latents.to(self.device) 263 if latents.device != self.device:
264 raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
263 265
264 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 266 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
265 if ensure_sigma: 267 if ensure_sigma: