diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 71 |
1 files changed, 36 insertions, 35 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a8ecedf..b4c85e9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
178 | # set timesteps | 178 | # set timesteps |
179 | self.scheduler.set_timesteps(num_inference_steps) | 179 | self.scheduler.set_timesteps(num_inference_steps) |
180 | 180 | ||
181 | offset = self.scheduler.config.get("steps_offset", 0) | ||
182 | |||
183 | if latents is not None and isinstance(latents, PIL.Image.Image): | ||
184 | latents = preprocess(latents, width, height) | ||
185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
186 | latents = latent_dist.sample(generator=generator) | ||
187 | latents = 0.18215 * latents | ||
188 | |||
189 | # expand init_latents for batch_size | ||
190 | latents = torch.cat([latents] * batch_size) | ||
191 | |||
192 | # get the original timestep using init_timestep | ||
193 | init_timestep = int(num_inference_steps * strength) + offset | ||
194 | init_timestep = min(init_timestep, num_inference_steps) | ||
195 | |||
196 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
197 | timesteps = torch.tensor( | ||
198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
199 | ) | ||
200 | else: | ||
201 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
203 | |||
204 | # add noise to latents using the timesteps | ||
205 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
206 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
207 | else: | ||
208 | init_timestep = num_inference_steps + offset | ||
209 | |||
210 | # get prompt text embeddings | 181 | # get prompt text embeddings |
211 | text_inputs = self.tokenizer( | 182 | text_inputs = self.tokenizer( |
212 | prompt, | 183 | prompt, |
@@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
243 | # to avoid doing two forward passes | 214 | # to avoid doing two forward passes |
244 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | 215 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
245 | 216 | ||
217 | offset = self.scheduler.config.get("steps_offset", 0) | ||
218 | init_timestep = num_inference_steps + offset | ||
219 | ensure_sigma = not isinstance(latents, PIL.Image.Image) | ||
220 | |||
246 | # get the initial random noise unless the user supplied it | 221 | # get the initial random noise unless the user supplied it |
247 | 222 | ||
248 | # Unlike in other pipelines, latents need to be generated in the target device | 223 | # Unlike in other pipelines, latents need to be generated in the target device |
@@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
257 | device=latents_device, | 232 | device=latents_device, |
258 | dtype=text_embeddings.dtype, | 233 | dtype=text_embeddings.dtype, |
259 | ) | 234 | ) |
235 | elif isinstance(latents, PIL.Image.Image): | ||
236 | latents = preprocess(latents, width, height) | ||
237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | ||
238 | latents = latent_dist.sample(generator=generator) | ||
239 | latents = 0.18215 * latents | ||
240 | |||
241 | # expand init_latents for batch_size | ||
242 | latents = torch.cat([latents] * batch_size) | ||
243 | |||
244 | # get the original timestep using init_timestep | ||
245 | init_timestep = int(num_inference_steps * strength) + offset | ||
246 | init_timestep = min(init_timestep, num_inference_steps) | ||
247 | |||
248 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
249 | timesteps = torch.tensor( | ||
250 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
251 | ) | ||
252 | else: | ||
253 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
254 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
255 | |||
256 | # add noise to latents using the timesteps | ||
257 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | ||
258 | latents = self.scheduler.add_noise(latents, noise, timesteps) | ||
260 | else: | 259 | else: |
261 | if latents.shape != latents_shape: | 260 | if latents.shape != latents_shape: |
262 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 261 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") |
263 | latents = latents.to(self.device) | 262 | latents = latents.to(self.device) |
264 | 263 | ||
264 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
265 | if ensure_sigma: | ||
266 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
267 | latents = latents * self.scheduler.sigmas[0] | ||
268 | elif isinstance(self.scheduler, EulerAScheduler): | ||
269 | latents = latents * self.scheduler.sigmas[0] | ||
270 | |||
265 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 271 | t_start = max(num_inference_steps - init_timestep + offset, 0) |
266 | 272 | ||
267 | # Some schedulers like PNDM have timesteps as arrays | 273 | # Some schedulers like PNDM have timesteps as arrays |
268 | # It's more optimzed to move all timesteps to correct device beforehand | 274 | # It's more optimzed to move all timesteps to correct device beforehand |
269 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) | 275 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) |
270 | 276 | ||
271 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
272 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
273 | latents = latents * self.scheduler.sigmas[0] | ||
274 | elif isinstance(self.scheduler, EulerAScheduler): | ||
275 | latents = latents * self.scheduler.sigmas[0] | ||
276 | |||
277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 279 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
@@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
292 | 292 | ||
293 | # expand the latents if we are doing classifier free guidance | 293 | # expand the latents if we are doing classifier free guidance |
294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 294 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
295 | |||
295 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 296 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
296 | sigma = self.scheduler.sigmas[t_index] | 297 | sigma = self.scheduler.sigmas[t_index] |
297 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | 298 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS |