Diffstat (limited to 'pipelines')
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4c793a8..a8ecedf 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
         latents = latent_dist.sample(generator=generator)
         latents = 0.18215 * latents
+
+        # expand init_latents for batch_size
         latents = torch.cat([latents] * batch_size)

         # get the original timestep using init_timestep
@@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             timesteps = torch.tensor(
                 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
             )
-        elif isinstance(self.scheduler, EulerAScheduler):
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
         else:
             timesteps = self.scheduler.timesteps[-init_timestep]
             timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
@@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
         elif isinstance(self.scheduler, EulerAScheduler):
-            sigma = self.scheduler.timesteps[0]
-            latents = latents * sigma
+            latents = latents * self.scheduler.sigmas[0]

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline):

             noise_pred = None
             if isinstance(self.scheduler, EulerAScheduler):
-                sigma = t.reshape(1)
+                sigma = self.scheduler.sigmas[t].reshape(1)
                 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
-                # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
                 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
                                                 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
-                # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
             else:
                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[t_index+1]
-                    latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, t_index + 1,
+                                              latents, **extra_step_kwargs).prev_sample
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

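For readers following the change: after this patch the EulerAScheduler branch looks up sigma from the scheduler's sigma table by timestep value and hands the current and next step indices straight to scheduler.step. Below is a minimal sketch of that branch, not the pipeline's actual method: the helper name euler_a_denoise is hypothetical, EulerAScheduler and CFGDenoiserForward are this repository's own classes (assumed to be in scope, as they are in vlpn_stable_diffusion.py), and the other arguments are assumed to be prepared exactly as in the surrounding pipeline code. For simplicity the sketch feeds latents directly where the pipeline uses latent_model_input.

import torch

def euler_a_denoise(scheduler, unet, latents, text_embeddings, guidance_scale, extra_step_kwargs):
    # Hypothetical helper illustrating only the patched Euler-ancestral path.
    # Initial latents are scaled by the first sigma, mirroring the LMSDiscreteScheduler branch.
    latents = latents * scheduler.sigmas[0]

    for t_index, t in enumerate(scheduler.timesteps):
        # sigma is read from the scheduler's sigma table, indexed by timestep value
        sigma = scheduler.sigmas[t].reshape(1)
        sigma_in = torch.cat([sigma] * latents.shape[0])

        # CFG-wrapped UNet call; CFGDenoiserForward and DSsigmas come from this repo's
        # custom Euler-ancestral scheduler code and are assumed importable here
        noise_pred = CFGDenoiserForward(unet, latents, sigma_in,
                                        text_embeddings, guidance_scale,
                                        quantize=True, DSsigmas=scheduler.DSsigmas)

        # the step call now receives the current and next step indices directly
        latents = scheduler.step(noise_pred, t_index, t_index + 1,
                                 latents, **extra_step_kwargs).prev_sample

    return latents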