summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-15 13:11:11 +0200
committerVolpeon <git@volpeon.ink>2023-04-15 13:11:11 +0200
commit99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch)
tree717a4099e9ebfedec702060fed5ed12aaceb0094 /pipelines
parentAdded cycle LR decay (diff)
downloadtextual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.gz
textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.bz2
textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.zip
TI via LoRA
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 13ea2ac..a0dff54 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
591 if callback is not None and i % callback_steps == 0: 591 if callback is not None and i % callback_steps == 0:
592 callback(i, t, latents) 592 callback(i, t, latents)
593 593
594 # 9. Post-processing
595 image = self.decode_latents(latents)
596
597 # 10. Run safety checker
598 has_nsfw_concept = None 594 has_nsfw_concept = None
599 595
600 # 11. Convert to PIL 596 if output_type == "latent":
601 if output_type == "pil": 597 image = latents
598 elif output_type == "pil":
599 # 9. Post-processing
600 image = self.decode_latents(latents)
601
602 # 10. Convert to PIL
602 image = self.numpy_to_pil(image) 603 image = self.numpy_to_pil(image)
604 else:
605 # 9. Post-processing
606 image = self.decode_latents(latents)
607
608 # Offload last model to CPU
609 if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
610 self.final_offload_hook.offload()
603 611
604 if not return_dict: 612 if not return_dict:
605 return (image, has_nsfw_concept) 613 return (image, has_nsfw_concept)