summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py42
1 files changed, 23 insertions, 19 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 141b9a7..707b639 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -421,25 +421,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
421 extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 421 extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
422 422
423 # 7. Denoising loop 423 # 7. Denoising loop
424 for i, t in enumerate(self.progress_bar(timesteps)): 424 num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
425 # expand the latents if we are doing classifier free guidance 425 with self.progress_bar(total=num_inference_steps) as progress_bar:
426 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 426 for i, t in enumerate(timesteps):
427 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 427 # expand the latents if we are doing classifier free guidance
428 428 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
429 # predict the noise residual 429 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
430 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 430
431 431 # predict the noise residual
432 # perform guidance 432 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
433 if do_classifier_free_guidance: 433
434 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 434 # perform guidance
435 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 435 if do_classifier_free_guidance:
436 436 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
437 # compute the previous noisy sample x_t -> x_t-1 437 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
438 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 438
439 439 # compute the previous noisy sample x_t -> x_t-1
440 # call the callback, if provided 440 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
441 if callback is not None and i % callback_steps == 0: 441
442 callback(i, t, latents) 442 # call the callback, if provided
443 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
444 progress_bar.update()
445 if callback is not None and i % callback_steps == 0:
446 callback(i, t, latents)
443 447
444 # 8. Post-processing 448 # 8. Post-processing
445 image = self.decode_latents(latents) 449 image = self.decode_latents(latents)