summary | refs | log | tree | commit | diff | stats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 33
1 files changed, 22 insertions, 11 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 36942f0..ba057ba 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -8,11 +8,20 @@ import PIL
8 8
9from diffusers.configuration_utils import FrozenDict 9from diffusers.configuration_utils import FrozenDict
10from diffusers.utils import is_accelerate_available 10from diffusers.utils import is_accelerate_available
11from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel 11from diffusers import (
12 AutoencoderKL,
13 DiffusionPipeline,
14 UNet2DConditionModel,
15 DDIMScheduler,
16 DPMSolverMultistepScheduler,
17 EulerAncestralDiscreteScheduler,
18 EulerDiscreteScheduler,
19 LMSDiscreteScheduler,
20 PNDMScheduler,
21)
12from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
13from diffusers.utils import logging 23from diffusers.utils import logging
14from transformers import CLIPTextModel, CLIPTokenizer 24from transformers import CLIPTextModel, CLIPTokenizer
15from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
16from models.clip.prompt import PromptProcessor 25from models.clip.prompt import PromptProcessor
17 26
18logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
33 text_encoder: CLIPTextModel, 42 text_encoder: CLIPTextModel,
34 tokenizer: CLIPTokenizer, 43 tokenizer: CLIPTokenizer,
35 unet: UNet2DConditionModel, 44 unet: UNet2DConditionModel,
36 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], 45 scheduler: Union[
46 DDIMScheduler,
47 PNDMScheduler,
48 LMSDiscreteScheduler,
49 EulerDiscreteScheduler,
50 EulerAncestralDiscreteScheduler,
51 DPMSolverMultistepScheduler,
52 ],
37 **kwargs, 53 **kwargs,
38 ): 54 ):
39 super().__init__() 55 super().__init__()
@@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
252 latents = 0.18215 * latents 268 latents = 0.18215 * latents
253 269
254 # expand init_latents for batch_size 270 # expand init_latents for batch_size
255 latents = torch.cat([latents] * batch_size) 271 latents = torch.cat([latents] * batch_size, dim=0)
256 272
257 # get the original timestep using init_timestep 273 # get the original timestep using init_timestep
258 init_timestep = int(num_inference_steps * strength) + offset 274 init_timestep = int(num_inference_steps * strength) + offset
259 init_timestep = min(init_timestep, num_inference_steps) 275 init_timestep = min(init_timestep, num_inference_steps)
260 276
261 if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): 277 timesteps = self.scheduler.timesteps[-init_timestep]
262 timesteps = torch.tensor( 278 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
263 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
264 )
265 else:
266 timesteps = self.scheduler.timesteps[-init_timestep]
267 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
268 279
269 # add noise to latents using the timesteps 280 # add noise to latents using the timesteps
270 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) 281 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)