From fcbc11be99c011ab1003451ef72c95ca587902d8 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 15 Oct 2022 18:42:27 +0200
Subject: Remove unused torch.optim import; inline get_text_embeddings helper

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

(limited to 'pipelines/stable_diffusion')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3e41f86..2656b28 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def get_text_embeddings(self, text_input_ids):
-        return self.text_encoder(text_input_ids)[0]
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
-- 
cgit v1.2.3-70-g09d2