diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 71 |
1 files changed, 36 insertions, 35 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a8ecedf..b4c85e9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 178 | # set timesteps | 178 | # set timesteps |
| 179 | self.scheduler.set_timesteps(num_inference_steps) | 179 | self.scheduler.set_timesteps(num_inference_steps) |
| 180 | 180 | ||
| 181 | offset = self.scheduler.config.get("steps_offset", 0) | ||
| 182 | |||
| 183 | if latents is not None and isinstance(latents, PIL.Image.Image): | ||
| 184 | latents = preprocess(latents, width, height) | ||
| 185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
| 186 | latents = latent_dist.sample(generator=generator) | ||
| 187 | latents = 0.18215 * latents | ||
| 188 | |||
| 189 | # expand init_latents for batch_size | ||
| 190 | latents = torch.cat([latents] * batch_size) | ||
| 191 | |||
| 192 | # get the original timestep using init_timestep | ||
| 193 | init_timestep = int(num_inference_steps * strength) + offset | ||
| 194 | init_timestep = min(init_timestep, num_inference_steps) | ||
| 195 | |||
| 196 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 197 | timesteps = torch.tensor( | ||
| 198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
| 199 | ) | ||
| 200 | else: | ||
| 201 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
| 202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
| 203 | |||
| 204 | # add noise to latents using the timesteps | ||
| 205 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
| 206 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
| 207 | else: | ||
| 208 | init_timestep = num_inference_steps + offset | ||
| 209 | |||
| 210 | # get prompt text embeddings | 181 | # get prompt text embeddings |
| 211 | text_inputs = self.tokenizer( | 182 | text_inputs = self.tokenizer( |
| 212 | prompt, | 183 | prompt, |
| @@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 243 | # to avoid doing two forward passes | 214 | # to avoid doing two forward passes |
| 244 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | 215 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
| 245 | 216 | ||
| 217 | offset = self.scheduler.config.get("steps_offset", 0) | ||
| 218 | init_timestep = num_inference_steps + offset | ||
| 219 | ensure_sigma = not isinstance(latents, PIL.Image.Image) | ||
| 220 | |||
| 246 | # get the initial random noise unless the user supplied it | 221 | # get the initial random noise unless the user supplied it |
| 247 | 222 | ||
| 248 | # Unlike in other pipelines, latents need to be generated in the target device | 223 | # Unlike in other pipelines, latents need to be generated in the target device |
| @@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 257 | device=latents_device, | 232 | device=latents_device, |
| 258 | dtype=text_embeddings.dtype, | 233 | dtype=text_embeddings.dtype, |
| 259 | ) | 234 | ) |
| 235 | elif isinstance(latents, PIL.Image.Image): | ||
| 236 | latents = preprocess(latents, width, height) | ||
| 237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
| 238 | latents = latent_dist.sample(generator=generator) | ||
| 239 | latents = 0.18215 * latents | ||
| 240 | |||
| 241 | # expand init_latents for batch_size | ||
| 242 | latents = torch.cat([latents] * batch_size) | ||
| 243 | |||
| 244 | # get the original timestep using init_timestep | ||
| 245 | init_timestep = int(num_inference_steps * strength) + offset | ||
| 246 | init_timestep = min(init_timestep, num_inference_steps) | ||
| 247 | |||
| 248 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 249 | timesteps = torch.tensor( | ||
| 250 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
| 251 | ) | ||
| 252 | else: | ||
| 253 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
| 254 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
| 255 | |||
| 256 | # add noise to latents using the timesteps | ||
| 257 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
| 258 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
| 260 | else: | 259 | else: |
| 261 | if latents.shape != latents_shape: | 260 | if latents.shape != latents_shape: |
| 262 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 261 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") |
| 263 | latents = latents.to(self.device) | 262 | latents = latents.to(self.device) |
| 264 | 263 | ||
| 264 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
| 265 | if ensure_sigma: | ||
| 266 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 267 | latents = latents * self.scheduler.sigmas[0] | ||
| 268 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 269 | latents = latents * self.scheduler.sigmas[0] | ||
| 270 | |||
| 265 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 271 | t_start = max(num_inference_steps - init_timestep + offset, 0) |
| 266 | 272 | ||
| 267 | # Some schedulers like PNDM have timesteps as arrays | 273 | # Some schedulers like PNDM have timesteps as arrays |
| 268 | # It's more optimzed to move all timesteps to correct device beforehand | 274 | # It's more optimzed to move all timesteps to correct device beforehand |
| 269 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) | 275 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) |
| 270 | 276 | ||
| 271 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
| 272 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 273 | latents = latents * self.scheduler.sigmas[0] | ||
| 274 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 275 | latents = latents * self.scheduler.sigmas[0] | ||
| 276 | |||
| 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
| 279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
| @@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 292 | 292 | ||
| 293 | # expand the latents if we are doing classifier free guidance | 293 | # expand the latents if we are doing classifier free guidance |
| 294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 295 | |||
| 295 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 296 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
| 296 | sigma = self.scheduler.sigmas[t_index] | 297 | sigma = self.scheduler.sigmas[t_index] |
| 297 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | 298 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS |
