From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:11:11 +0200 Subject: TI via LoRA --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 13ea2ac..a0dff54 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker has_nsfw_concept = None - # 11. Convert to PIL - if output_type == "pil": + if output_type == "latent": + image = latents + elif output_type == "pil": + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL image = self.numpy_to_pil(image) + else: + # 9. Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) -- cgit v1.2.3-70-g09d2