author     Volpeon <git@volpeon.ink>    2023-01-13 07:25:24 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-13 07:25:24 +0100
commit     89d471652644f449966a0cd944041c98dab7f66c
tree       4cc797369a5c781b4978b89a61023c4de7fde606 /pipelines
parent     Update
Code deduplication
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 32
1 file changed, 9 insertions(+), 23 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
 
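Note: randn_tensor is diffusers' shared noise-sampling helper (importable from diffusers.utils at the time of this commit; later releases house it in diffusers.utils.torch_utils). It encapsulates the list-of-generators and MPS special-casing that the hunks below delete. A minimal usage sketch, with illustrative shape and seed values:

    import torch
    from diffusers.utils import randn_tensor

    # A list of generators draws each batch element from its own RNG stream,
    # so sample i stays reproducible regardless of the rest of the batch.
    shape = (2, 4, 64, 64)  # (batch, latent channels, height // 8, width // 8)
    generators = [torch.Generator("cpu").manual_seed(i) for i in range(shape[0])]
    latents = randn_tensor(shape, generator=generators,
                           device=torch.device("cpu"), dtype=torch.float32)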
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
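Replacing the hard-coded // 8 with self.vae_scale_factor ties the latent resolution to the VAE instead of a magic number. A sketch of the arithmetic, assuming the standard Stable Diffusion VAE config with four block_out_channels:

    # diffusers pipelines conventionally derive the scale factor from the
    # number of VAE downsampling stages:
    block_out_channels = [128, 256, 512, 512]               # vae.config.block_out_channels
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # -> 8

    batch_size, num_channels_latents, height, width = 2, 4, 512, 512
    shape = (batch_size, num_channels_latents,
             height // vae_scale_factor, width // vae_scale_factor)
    assert shape == (2, 4, 64, 64)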
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
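The deleted branch hand-rolled what randn_tensor already provides: per-sample draws when a list of generators is passed, concatenated along the batch dimension, with sampling forced to CPU on MPS. (The explicit shape check on caller-supplied latents is dropped in the process.) A rough standalone sketch of the replaced logic; manual_randn is a hypothetical name, not pipeline code:

    import torch

    def manual_randn(shape, generator, device, dtype):
        # Mirrors the removed if/else: one draw per generator, stacked on dim 0.
        if isinstance(generator, list):
            per_sample = (1,) + shape[1:]
            samples = [torch.randn(per_sample, generator=g, dtype=dtype)
                       for g in generator]
            return torch.cat(samples, dim=0).to(device)
        return torch.randn(shape, generator=generator, dtype=dtype).to(device)

    gens = lambda: [torch.Generator().manual_seed(i) for i in range(2)]
    a = manual_randn((2, 4, 8, 8), gens(), "cpu", torch.float32)
    b = manual_randn((2, 4, 8, 8), gens(), "cpu", torch.float32)
    assert torch.equal(a, b)  # each sample reproducible from its own seed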
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
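Since the caller now folds num_images_per_prompt into batch_size (see the final hunk below), a single encoded init image is tiled straight to the final batch size. With illustrative values:

    import torch

    init_latents = torch.randn(1, 4, 64, 64)  # one encoded init image
    batch_size = 2 * 3                        # 2 prompts x 3 images per prompt
    init_latents = torch.cat([init_latents] * batch_size, dim=0)
    assert init_latents.shape == (6, 4, 64, 64)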
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
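Net effect: both helpers now take a single effective batch size, and each call site passes batch_size * num_images_per_prompt. The hunk is truncated above; a sketch of how the full call presumably reads, with the trailing arguments inferred from prepare_latents' new signature in the second hunk:

    # Sketch of the call site; the arguments after width are an assumption
    # based on prepare_latents(self, batch_size, num_channels_latents,
    # height, width, dtype, device, generator, latents=None).
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
    )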