diff options
author | Volpeon <git@volpeon.ink> | 2022-12-19 21:11:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-19 21:11:21 +0100 |
commit | c36b3a32964a5701edd3df1df8533cfdbb89d3cf (patch) | |
tree | fa54611729eba6a9f7b8395401d304f204c8dd64 | |
parent | Improved dataset prompt handling, fixed (diff) | |
download | textual-inversion-diff-c36b3a32964a5701edd3df1df8533cfdbb89d3cf.tar.gz textual-inversion-diff-c36b3a32964a5701edd3df1df8533cfdbb89d3cf.tar.bz2 textual-inversion-diff-c36b3a32964a5701edd3df1df8533cfdbb89d3cf.zip |
Upstream patches
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 22 |
1 file changed, 17 insertions, 5 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 707b639..a43a8e4 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -236,12 +236,24 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
236 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): | 236 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): |
237 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) | 237 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) |
238 | 238 | ||
239 | if isinstance(generator, list) and len(generator) != batch_size: | ||
240 | raise ValueError( | ||
241 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | ||
242 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||
243 | ) | ||
244 | |||
239 | if latents is None: | 245 | if latents is None: |
240 | if device.type == "mps": | 246 | rand_device = "cpu" if device.type == "mps" else device |
241 | # randn does not work reproducibly on mps | 247 | |
242 | latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) | 248 | if isinstance(generator, list): |
249 | shape = (1,) + shape[1:] | ||
250 | latents = [ | ||
251 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) | ||
252 | for i in range(batch_size) | ||
253 | ] | ||
254 | latents = torch.cat(latents, dim=0).to(device) | ||
243 | else: | 255 | else: |
244 | latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) | 256 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) |
245 | else: | 257 | else: |
246 | if latents.shape != shape: | 258 | if latents.shape != shape: |
247 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") | 259 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") |
@@ -311,7 +323,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
311 | num_inference_steps: Optional[int] = 50, | 323 | num_inference_steps: Optional[int] = 50, |
312 | guidance_scale: Optional[float] = 7.5, | 324 | guidance_scale: Optional[float] = 7.5, |
313 | eta: Optional[float] = 0.0, | 325 | eta: Optional[float] = 0.0, |
314 | generator: Optional[torch.Generator] = None, | 326 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
315 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 327 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
316 | output_type: Optional[str] = "pil", | 328 | output_type: Optional[str] = "pil", |
317 | return_dict: bool = True, | 329 | return_dict: bool = True, |