summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 07:25:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 07:25:24 +0100
commit89d471652644f449966a0cd944041c98dab7f66c (patch)
tree4cc797369a5c781b4978b89a61023c4de7fde606 /pipelines
parentUpdate (diff)
downloadtextual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.tar.gz
textual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.tar.bz2
textual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.zip
Code deduplication
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py32
1 files changed, 9 insertions, 23 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
20 PNDMScheduler, 20 PNDMScheduler,
21) 21)
22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
23from diffusers.utils import logging 23from diffusers.utils import logging, randn_tensor
24from transformers import CLIPTextModel, CLIPTokenizer 24from transformers import CLIPTextModel, CLIPTokenizer
25from models.clip.prompt import PromptProcessor 25from models.clip.prompt import PromptProcessor
26 26
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
250 250
251 return timesteps 251 return timesteps
252 252
253 def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): 253 def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
254 shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) 254 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
255 255
256 if isinstance(generator, list) and len(generator) != batch_size: 256 if isinstance(generator, list) and len(generator) != batch_size:
257 raise ValueError( 257 raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
260 ) 260 )
261 261
262 if latents is None: 262 if latents is None:
263 rand_device = "cpu" if device.type == "mps" else device 263 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
264
265 if isinstance(generator, list):
266 shape = (1,) + shape[1:]
267 latents = [
268 torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
269 for i in range(batch_size)
270 ]
271 latents = torch.cat(latents, dim=0).to(device)
272 else:
273 latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
274 else: 264 else:
275 if latents.shape != shape: 265 latents = latents.to(device=device, dtype=dtype)
276 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
277 latents = latents.to(device)
278 266
279 # scale the initial noise by the standard deviation required by the scheduler 267 # scale the initial noise by the standard deviation required by the scheduler
280 latents = latents * self.scheduler.init_noise_sigma 268 latents = latents * self.scheduler.init_noise_sigma
281 269
282 return latents 270 return latents
283 271
284 def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): 272 def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
285 init_image = init_image.to(device=device, dtype=dtype) 273 init_image = init_image.to(device=device, dtype=dtype)
286 init_latent_dist = self.vae.encode(init_image).latent_dist 274 init_latent_dist = self.vae.encode(init_image).latent_dist
287 init_latents = init_latent_dist.sample(generator=generator) 275 init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
292 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 280 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
293 ) 281 )
294 else: 282 else:
295 init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) 283 init_latents = torch.cat([init_latents] * batch_size, dim=0)
296 284
297 # add noise to latents using the timesteps 285 # add noise to latents using the timesteps
298 noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) 286 noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
430 latents = self.prepare_latents_from_image( 418 latents = self.prepare_latents_from_image(
431 image, 419 image,
432 latent_timestep, 420 latent_timestep,
433 batch_size, 421 batch_size * num_images_per_prompt,
434 num_images_per_prompt,
435 text_embeddings.dtype, 422 text_embeddings.dtype,
436 device, 423 device,
437 generator 424 generator
438 ) 425 )
439 else: 426 else:
440 latents = self.prepare_latents( 427 latents = self.prepare_latents(
441 batch_size, 428 batch_size * num_images_per_prompt,
442 num_images_per_prompt,
443 num_channels_latents, 429 num_channels_latents,
444 height, 430 height,
445 width, 431 width,