Diffstat (limited to 'pipelines/stable_diffusion')
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 36942f0..ba057ba 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -8,11 +8,20 @@ import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers.utils import is_accelerate_available
-from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DiffusionPipeline,
+    UNet2DConditionModel,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
         **kwargs,
     ):
         super().__init__()
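
With the widened Union, the constructor now accepts any of the six schedulers straight from diffusers. A minimal construction sketch, assuming the constructor also takes a `vae` argument (not visible in this hunk, but standard for Stable Diffusion pipelines) and using a placeholder checkpoint id:

    # Sketch only: the checkpoint id and the `vae` keyword are assumptions,
    # not part of this diff.
    from diffusers import (
        AutoencoderKL,
        DPMSolverMultistepScheduler,
        UNet2DConditionModel,
    )
    from transformers import CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    model_id = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint

    pipe = VlpnStableDiffusion(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        scheduler=DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler"),
    )
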
@@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         latents = 0.18215 * latents
 
         # expand init_latents for batch_size
-        latents = torch.cat([latents] * batch_size)
+        latents = torch.cat([latents] * batch_size, dim=0)
 
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
 
-        if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
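
The collapsed branch relies on every supported scheduler exposing its discrete schedule as `timesteps`, which makes the old DDIM/PNDM special case unnecessary. A standalone sketch of the selection logic under that assumption, with illustrative values:

    import torch
    from diffusers import DDIMScheduler

    num_inference_steps = 50
    strength = 0.75   # fraction of the schedule to rerun on the init image
    offset = 1        # steps offset, mirroring the pipeline's earlier handling
    batch_size = 2

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps)

    # Same arithmetic as the hunk above: how many trailing inference steps
    # the init image should pass through.
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)

    # Index the scheduler's own schedule instead of branching per class.
    t_start = scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([t_start] * batch_size)

    # The chosen timestep then drives the noising of the init latents.
    latents = torch.randn(batch_size, 4, 64, 64)
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)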
