summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
commit49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch)
tree8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /pipelines
parentInference: Add support for embeddings (diff)
downloadtextual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.gz
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.bz2
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.zip
Update
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py35
1 files changed, 5 insertions, 30 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8fbe5f9..a198cf6 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -216,7 +216,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
216 216
217 offset = self.scheduler.config.get("steps_offset", 0) 217 offset = self.scheduler.config.get("steps_offset", 0)
218 init_timestep = num_inference_steps + offset 218 init_timestep = num_inference_steps + offset
219 ensure_sigma = not isinstance(latents, PIL.Image.Image)
220 219
221 # get the initial random noise unless the user supplied it 220 # get the initial random noise unless the user supplied it
222 221
@@ -246,13 +245,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
246 init_timestep = int(num_inference_steps * strength) + offset 245 init_timestep = int(num_inference_steps * strength) + offset
247 init_timestep = min(init_timestep, num_inference_steps) 246 init_timestep = min(init_timestep, num_inference_steps)
248 247
249 if isinstance(self.scheduler, LMSDiscreteScheduler): 248 timesteps = self.scheduler.timesteps[-init_timestep]
250 timesteps = torch.tensor( 249 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
251 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
252 )
253 else:
254 timesteps = self.scheduler.timesteps[-init_timestep]
255 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
256 250
257 # add noise to latents using the timesteps 251 # add noise to latents using the timesteps
258 noise = torch.randn(latents.shape, generator=generator, device=self.device) 252 noise = torch.randn(latents.shape, generator=generator, device=self.device)
@@ -263,13 +257,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
263 if latents.device != self.device: 257 if latents.device != self.device:
264 raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") 258 raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
265 259
266 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
267 if ensure_sigma:
268 if isinstance(self.scheduler, LMSDiscreteScheduler):
269 latents = latents * self.scheduler.sigmas[0]
270 elif isinstance(self.scheduler, EulerAScheduler):
271 latents = latents * self.scheduler.sigmas[0]
272
273 t_start = max(num_inference_steps - init_timestep + offset, 0) 260 t_start = max(num_inference_steps - init_timestep + offset, 0)
274 261
275 # Some schedulers like PNDM have timesteps as arrays 262 # Some schedulers like PNDM have timesteps as arrays
@@ -290,19 +277,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
290 extra_step_kwargs["generator"] = generator 277 extra_step_kwargs["generator"] = generator
291 278
292 for i, t in enumerate(self.progress_bar(timesteps_tensor)): 279 for i, t in enumerate(self.progress_bar(timesteps_tensor)):
293 t_index = t_start + i
294
295 # expand the latents if we are doing classifier free guidance 280 # expand the latents if we are doing classifier free guidance
296 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 281 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
297 282 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
298 if isinstance(self.scheduler, LMSDiscreteScheduler):
299 sigma = self.scheduler.sigmas[t_index]
300 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
301 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
302 283
303 noise_pred = None 284 noise_pred = None
304 if isinstance(self.scheduler, EulerAScheduler): 285 if isinstance(self.scheduler, EulerAScheduler):
305 sigma = self.scheduler.sigmas[t].reshape(1) 286 sigma = t.reshape(1)
306 sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) 287 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
307 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, 288 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
308 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) 289 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
@@ -316,13 +297,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
316 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 297 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
317 298
318 # compute the previous noisy sample x_t -> x_t-1 299 # compute the previous noisy sample x_t -> x_t-1
319 if isinstance(self.scheduler, LMSDiscreteScheduler): 300 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
320 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
321 elif isinstance(self.scheduler, EulerAScheduler):
322 latents = self.scheduler.step(noise_pred, t_index, t_index + 1,
323 latents, **extra_step_kwargs).prev_sample
324 else:
325 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
326 301
327 # scale and decode the image latents with vae 302 # scale and decode the image latents with vae
328 latents = 1 / 0.18215 * latents 303 latents = 1 / 0.18215 * latents