summary refs log tree commit diff stats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 71
1 files changed, 36 insertions, 35 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a8ecedf..b4c85e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
178 # set timesteps 178 # set timesteps
179 self.scheduler.set_timesteps(num_inference_steps) 179 self.scheduler.set_timesteps(num_inference_steps)
180 180
181 offset = self.scheduler.config.get("steps_offset", 0)
182
183 if latents is not None and isinstance(latents, PIL.Image.Image):
184 latents = preprocess(latents, width, height)
185 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
186 latents = latent_dist.sample(generator=generator)
187 latents = 0.18215 * latents
188
189 # expand init_latents for batch_size
190 latents = torch.cat([latents] * batch_size)
191
192 # get the original timestep using init_timestep
193 init_timestep = int(num_inference_steps * strength) + offset
194 init_timestep = min(init_timestep, num_inference_steps)
195
196 if isinstance(self.scheduler, LMSDiscreteScheduler):
197 timesteps = torch.tensor(
198 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
199 )
200 else:
201 timesteps = self.scheduler.timesteps[-init_timestep]
202 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
203
204 # add noise to latents using the timesteps
205 noise = torch.randn(latents.shape, generator=generator, device=self.device)
206 latents = self.scheduler.add_noise(latents, noise, timesteps)
207 else:
208 init_timestep = num_inference_steps + offset
209
210 # get prompt text embeddings 181 # get prompt text embeddings
211 text_inputs = self.tokenizer( 182 text_inputs = self.tokenizer(
212 prompt, 183 prompt,
@@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
243 # to avoid doing two forward passes 214 # to avoid doing two forward passes
244 text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 215 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
245 216
217 offset = self.scheduler.config.get("steps_offset", 0)
218 init_timestep = num_inference_steps + offset
219 ensure_sigma = not isinstance(latents, PIL.Image.Image)
220
246 # get the initial random noise unless the user supplied it 221 # get the initial random noise unless the user supplied it
247 222
248 # Unlike in other pipelines, latents need to be generated in the target device 223 # Unlike in other pipelines, latents need to be generated in the target device
@@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline):
257 device=latents_device, 232 device=latents_device,
258 dtype=text_embeddings.dtype, 233 dtype=text_embeddings.dtype,
259 ) 234 )
235 elif isinstance(latents, PIL.Image.Image):
236 latents = preprocess(latents, width, height)
237 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
238 latents = latent_dist.sample(generator=generator)
239 latents = 0.18215 * latents
240
241 # expand init_latents for batch_size
242 latents = torch.cat([latents] * batch_size)
243
244 # get the original timestep using init_timestep
245 init_timestep = int(num_inference_steps * strength) + offset
246 init_timestep = min(init_timestep, num_inference_steps)
247
248 if isinstance(self.scheduler, LMSDiscreteScheduler):
249 timesteps = torch.tensor(
250 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
251 )
252 else:
253 timesteps = self.scheduler.timesteps[-init_timestep]
254 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
255
256 # add noise to latents using the timesteps
257 noise = torch.randn(latents.shape, generator=generator, device=self.device)
258 latents = self.scheduler.add_noise(latents, noise, timesteps)
260 else: 259 else:
261 if latents.shape != latents_shape: 260 if latents.shape != latents_shape:
262 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 261 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
263 latents = latents.to(self.device) 262 latents = latents.to(self.device)
264 263
264 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
265 if ensure_sigma:
266 if isinstance(self.scheduler, LMSDiscreteScheduler):
267 latents = latents * self.scheduler.sigmas[0]
268 elif isinstance(self.scheduler, EulerAScheduler):
269 latents = latents * self.scheduler.sigmas[0]
270
265 t_start = max(num_inference_steps - init_timestep + offset, 0) 271 t_start = max(num_inference_steps - init_timestep + offset, 0)
266 272
267 # Some schedulers like PNDM have timesteps as arrays 273 # Some schedulers like PNDM have timesteps as arrays
268 # It's more optimized to move all timesteps to correct device beforehand 274 # It's more optimized to move all timesteps to correct device beforehand
269 timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) 275 timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
270 276
271 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
272 if isinstance(self.scheduler, LMSDiscreteScheduler):
273 latents = latents * self.scheduler.sigmas[0]
274 elif isinstance(self.scheduler, EulerAScheduler):
275 latents = latents * self.scheduler.sigmas[0]
276
277 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 277 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
278 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 278 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
279 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 279 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
292 292
293 # expand the latents if we are doing classifier free guidance 293 # expand the latents if we are doing classifier free guidance
294 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 294 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
295
295 if isinstance(self.scheduler, LMSDiscreteScheduler): 296 if isinstance(self.scheduler, LMSDiscreteScheduler):
296 sigma = self.scheduler.sigmas[t_index] 297 sigma = self.scheduler.sigmas[t_index]
297 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 298 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS