From fcbc11be99c011ab1003451ef72c95ca587902d8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Oct 2022 18:42:27 +0200 Subject: Update --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3e41f86..2656b28 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -4,7 +4,6 @@ from typing import List, Optional, Union import numpy as np import torch -import torch.optim as optim import PIL from diffusers.configuration_utils import FrozenDict @@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler=scheduler, ) - def get_text_embeddings(self, text_input_ids): - return self.text_encoder(text_input_ids)[0] - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. @@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline): ) print(f"Too many tokens: {removed_text}") text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device)) + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline): uncond_input = self.tokenizer( negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device)) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch -- cgit v1.2.3-54-g00ecf