summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-15 18:42:27 +0200
committerVolpeon <git@volpeon.ink>2022-10-15 18:42:27 +0200
commitfcbc11be99c011ab1003451ef72c95ca587902d8 (patch)
tree8a8416e2777874addd05fa2f59896a31f044f1fc /pipelines
parentRemoved aesthetic gradients; training improvements (diff)
downloadtextual-inversion-diff-fcbc11be99c011ab1003451ef72c95ca587902d8.tar.gz
textual-inversion-diff-fcbc11be99c011ab1003451ef72c95ca587902d8.tar.bz2
textual-inversion-diff-fcbc11be99c011ab1003451ef72c95ca587902d8.zip
Update
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py8
1 files changed, 2 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3e41f86..2656b28 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
4 4
5import numpy as np 5import numpy as np
6import torch 6import torch
7import torch.optim as optim
8import PIL 7import PIL
9 8
10from diffusers.configuration_utils import FrozenDict 9from diffusers.configuration_utils import FrozenDict
@@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
59 scheduler=scheduler, 58 scheduler=scheduler,
60 ) 59 )
61 60
62 def get_text_embeddings(self, text_input_ids):
63 return self.text_encoder(text_input_ids)[0]
64
65 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 61 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
66 r""" 62 r"""
67 Enable sliced attention computation. 63 Enable sliced attention computation.
@@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
199 ) 195 )
200 print(f"Too many tokens: {removed_text}") 196 print(f"Too many tokens: {removed_text}")
201 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 197 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
202 text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device)) 198 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
203 199
204 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 200 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
205 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 201 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
211 uncond_input = self.tokenizer( 207 uncond_input = self.tokenizer(
212 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" 208 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
213 ) 209 )
214 uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device)) 210 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
215 211
216 # For classifier free guidance, we need to do two forward passes. 212 # For classifier free guidance, we need to do two forward passes.
217 # Here we concatenate the unconditional and text embeddings into a single batch 213 # Here we concatenate the unconditional and text embeddings into a single batch