diff options
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 8 |
1 files changed, 2 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3e41f86..2656b28 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -4,7 +4,6 @@ from typing import List, Optional, Union | |||
4 | 4 | ||
5 | import numpy as np | 5 | import numpy as np |
6 | import torch | 6 | import torch |
7 | import torch.optim as optim | ||
8 | import PIL | 7 | import PIL |
9 | 8 | ||
10 | from diffusers.configuration_utils import FrozenDict | 9 | from diffusers.configuration_utils import FrozenDict |
@@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
59 | scheduler=scheduler, | 58 | scheduler=scheduler, |
60 | ) | 59 | ) |
61 | 60 | ||
62 | def get_text_embeddings(self, text_input_ids): | ||
63 | return self.text_encoder(text_input_ids)[0] | ||
64 | |||
65 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 61 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
66 | r""" | 62 | r""" |
67 | Enable sliced attention computation. | 63 | Enable sliced attention computation. |
@@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
199 | ) | 195 | ) |
200 | print(f"Too many tokens: {removed_text}") | 196 | print(f"Too many tokens: {removed_text}") |
201 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | 197 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] |
202 | text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device)) | 198 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] |
203 | 199 | ||
204 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | 200 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) |
205 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | 201 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` |
@@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
211 | uncond_input = self.tokenizer( | 207 | uncond_input = self.tokenizer( |
212 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" | 208 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" |
213 | ) | 209 | ) |
214 | uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device)) | 210 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] |
215 | 211 | ||
216 | # For classifier free guidance, we need to do two forward passes. | 212 | # For classifier free guidance, we need to do two forward passes. |
217 | # Here we concatenate the unconditional and text embeddings into a single batch | 213 | # Here we concatenate the unconditional and text embeddings into a single batch |