diff options
author | Volpeon <git@volpeon.ink> | 2022-12-13 09:40:34 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-13 09:40:34 +0100 |
commit | b33ac00de283fe45edba689990dc96a5de93cd1e (patch) | |
tree | a3106f2e482f9e4b2ab9d9ff49faf0b529278f50 /pipelines | |
parent | Dreambooth: Support loading Textual Inversion embeddings (diff) | |
download | textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.tar.gz textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.tar.bz2 textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.zip |
Add support for resume in Textual Inversion
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 141b9a7..707b639 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -421,25 +421,29 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
421 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 421 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
422 | 422 | ||
423 | # 7. Denoising loop | 423 | # 7. Denoising loop |
424 | for i, t in enumerate(self.progress_bar(timesteps)): | 424 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
425 | # expand the latents if we are doing classifier free guidance | 425 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
426 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 426 | for i, t in enumerate(timesteps): |
427 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 427 | # expand the latents if we are doing classifier free guidance |
428 | 428 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | |
429 | # predict the noise residual | 429 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
430 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 430 | |
431 | 431 | # predict the noise residual | |
432 | # perform guidance | 432 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
433 | if do_classifier_free_guidance: | 433 | |
434 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 434 | # perform guidance |
435 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 435 | if do_classifier_free_guidance: |
436 | 436 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
437 | # compute the previous noisy sample x_t -> x_t-1 | 437 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
438 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 438 | |
439 | 439 | # compute the previous noisy sample x_t -> x_t-1 | |
440 | # call the callback, if provided | 440 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
441 | if callback is not None and i % callback_steps == 0: | 441 | |
442 | callback(i, t, latents) | 442 | # call the callback, if provided |
443 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | ||
444 | progress_bar.update() | ||
445 | if callback is not None and i % callback_steps == 0: | ||
446 | callback(i, t, latents) | ||
443 | 447 | ||
444 | # 8. Post-processing | 448 | # 8. Post-processing |
445 | image = self.decode_latents(latents) | 449 | image = self.decode_latents(latents) |