summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--infer.py2
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py17
-rw-r--r--schedulers/scheduling_euler_a.py59
3 files changed, 33 insertions, 45 deletions
diff --git a/infer.py b/infer.py
index b440cb6..c40335c 100644
--- a/infer.py
+++ b/infer.py
@@ -176,7 +176,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
176 ) 176 )
177 else: 177 else:
178 scheduler = EulerAScheduler( 178 scheduler = EulerAScheduler(
179 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False 179 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
180 ) 180 )
181 181
182 pipeline = VlpnStableDiffusion( 182 pipeline = VlpnStableDiffusion(
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4c793a8..a8ecedf 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
185 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist 185 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
186 latents = latent_dist.sample(generator=generator) 186 latents = latent_dist.sample(generator=generator)
187 latents = 0.18215 * latents 187 latents = 0.18215 * latents
188
189 # expand init_latents for batch_size
188 latents = torch.cat([latents] * batch_size) 190 latents = torch.cat([latents] * batch_size)
189 191
190 # get the original timestep using init_timestep 192 # get the original timestep using init_timestep
@@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
195 timesteps = torch.tensor( 197 timesteps = torch.tensor(
196 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device 198 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
197 ) 199 )
198 elif isinstance(self.scheduler, EulerAScheduler):
199 timesteps = self.scheduler.timesteps[-init_timestep]
200 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
201 else: 200 else:
202 timesteps = self.scheduler.timesteps[-init_timestep] 201 timesteps = self.scheduler.timesteps[-init_timestep]
203 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) 202 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
@@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
273 if isinstance(self.scheduler, LMSDiscreteScheduler): 272 if isinstance(self.scheduler, LMSDiscreteScheduler):
274 latents = latents * self.scheduler.sigmas[0] 273 latents = latents * self.scheduler.sigmas[0]
275 elif isinstance(self.scheduler, EulerAScheduler): 274 elif isinstance(self.scheduler, EulerAScheduler):
276 sigma = self.scheduler.timesteps[0] 275 latents = latents * self.scheduler.sigmas[0]
277 latents = latents * sigma
278 276
279 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 277 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
280 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 278 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
301 299
302 noise_pred = None 300 noise_pred = None
303 if isinstance(self.scheduler, EulerAScheduler): 301 if isinstance(self.scheduler, EulerAScheduler):
304 sigma = t.reshape(1) 302 sigma = self.scheduler.sigmas[t].reshape(1)
305 sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) 303 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
306 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
307 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, 304 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
308 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) 305 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
309 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
310 else: 306 else:
311 # predict the noise residual 307 # predict the noise residual
312 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 308 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
320 if isinstance(self.scheduler, LMSDiscreteScheduler): 316 if isinstance(self.scheduler, LMSDiscreteScheduler):
321 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample 317 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
322 elif isinstance(self.scheduler, EulerAScheduler): 318 elif isinstance(self.scheduler, EulerAScheduler):
323 if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error 319 latents = self.scheduler.step(noise_pred, t_index, t_index + 1,
324 t_prev = self.scheduler.timesteps[t_index+1] 320 latents, **extra_step_kwargs).prev_sample
325 latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
326 else: 321 else:
327 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 322 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
328 323
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 9fbedaa..1b1c9cf 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -1,7 +1,3 @@
1
2
3import math
4import warnings
5from typing import Optional, Tuple, Union 1from typing import Optional, Tuple, Union
6 2
7import numpy as np 3import numpy as np
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
157 beta_end: float = 0.02, 153 beta_end: float = 0.02,
158 beta_schedule: str = "linear", 154 beta_schedule: str = "linear",
159 trained_betas: Optional[np.ndarray] = None, 155 trained_betas: Optional[np.ndarray] = None,
160 clip_sample: bool = True,
161 set_alpha_to_one: bool = True,
162 steps_offset: int = 0,
163 ): 156 ):
164 if trained_betas is not None: 157 if trained_betas is not None:
165 self.betas = torch.from_numpy(trained_betas) 158 self.betas = torch.from_numpy(trained_betas)
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
177 self.alphas = 1.0 - self.betas 170 self.alphas = 1.0 - self.betas
178 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 171 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
179 172
180 # At every step in ddim, we are looking into the previous alphas_cumprod
181 # For the final step, there is no previous alphas_cumprod because we are already at 0
182 # `set_alpha_to_one` decides whether we set this parameter simply to one or
183 # whether we use the final alpha of the "non-previous" one.
184 self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
185
186 # setable values 173 # setable values
187 self.num_inference_steps = None 174 self.num_inference_steps = None
188 self.timesteps = np.arange(0, num_train_timesteps)[::-1] 175 self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
199 the number of diffusion steps used when generating samples with a pre-trained model. 186 the number of diffusion steps used when generating samples with a pre-trained model.
200 """ 187 """
201 188
202 # offset = self.config.steps_offset
203
204 # if "offset" in kwargs:
205 # warnings.warn(
206 # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
207 # " Please pass `steps_offset` to `__init__` instead.",
208 # DeprecationWarning,
209 # )
210
211 # offset = kwargs["offset"]
212
213 self.num_inference_steps = num_inference_steps 189 self.num_inference_steps = num_inference_steps
214 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 190 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
215 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) 191 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
216 self.timesteps = self.sigmas 192 self.timesteps = np.arange(0, self.num_inference_steps)
217 193
218 def add_noise_to_input( 194 def add_noise_to_input(
219 self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None 195 self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
@@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
239 def step( 215 def step(
240 self, 216 self,
241 model_output: torch.FloatTensor, 217 model_output: torch.FloatTensor,
242 timestep: torch.IntTensor, 218 timestep: int,
243 timestep_prev: torch.IntTensor, 219 timestep_prev: int,
244 sample: torch.FloatTensor, 220 sample: torch.FloatTensor,
245 generator: None, 221 generator: None,
246 # ,sigma_hat: float, 222 # ,sigma_hat: float,
@@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
266 returning a tuple, the first element is the sample tensor. 242 returning a tuple, the first element is the sample tensor.
267 243
268 """ 244 """
245 s = self.sigmas[timestep]
246 s_prev = self.sigmas[timestep_prev]
269 latents = sample 247 latents = sample
270 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) 248
271 d = to_d(latents, timestep, model_output) 249 sigma_down, sigma_up = get_ancestral_step(s, s_prev)
272 dt = sigma_down - timestep 250 d = to_d(latents, s, model_output)
251 dt = sigma_down - s
273 latents = latents + d * dt 252 latents = latents + d * dt
274 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, 253 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
275 generator=generator) * sigma_up 254 generator=generator) * sigma_up
255
276 return SchedulerOutput(prev_sample=latents) 256 return SchedulerOutput(prev_sample=latents)
277 257
278 def step_correct( 258 def step_correct(
@@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
311 291
312 return SchedulerOutput(prev_sample=sample_prev) 292 return SchedulerOutput(prev_sample=sample_prev)
313 293
314 def add_noise(self, original_samples, noise, timesteps): 294 def add_noise(
315 raise NotImplementedError() 295 self,
296 original_samples: torch.FloatTensor,
297 noise: torch.FloatTensor,
298 timesteps: torch.IntTensor,
299 ) -> torch.FloatTensor:
300 sigmas = self.sigmas.to(original_samples.device)
301 timesteps = timesteps.to(original_samples.device)
302
303 sigma = sigmas[timesteps].flatten()
304 while len(sigma.shape) < len(original_samples.shape):
305 sigma = sigma.unsqueeze(-1)
306
307 noisy_samples = original_samples + noise * sigma
308 return noisy_samples