diff options
-rw-r--r-- | infer.py | 2 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 | ||||
-rw-r--r-- | schedulers/scheduling_euler_a.py | 59 |
3 files changed, 33 insertions, 45 deletions
@@ -176,7 +176,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16): | |||
176 | ) | 176 | ) |
177 | else: | 177 | else: |
178 | scheduler = EulerAScheduler( | 178 | scheduler = EulerAScheduler( |
179 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 179 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
180 | ) | 180 | ) |
181 | 181 | ||
182 | pipeline = VlpnStableDiffusion( | 182 | pipeline = VlpnStableDiffusion( |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4c793a8..a8ecedf 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist |
186 | latents = latent_dist.sample(generator=generator) | 186 | latents = latent_dist.sample(generator=generator) |
187 | latents = 0.18215 * latents | 187 | latents = 0.18215 * latents |
188 | |||
189 | # expand init_latents for batch_size | ||
188 | latents = torch.cat([latents] * batch_size) | 190 | latents = torch.cat([latents] * batch_size) |
189 | 191 | ||
190 | # get the original timestep using init_timestep | 192 | # get the original timestep using init_timestep |
@@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
195 | timesteps = torch.tensor( | 197 | timesteps = torch.tensor( |
196 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | 198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device |
197 | ) | 199 | ) |
198 | elif isinstance(self.scheduler, EulerAScheduler): | ||
199 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
200 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | ||
201 | else: | 200 | else: |
202 | timesteps = self.scheduler.timesteps[-init_timestep] | 201 | timesteps = self.scheduler.timesteps[-init_timestep] |
203 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | 202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) |
@@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
273 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 272 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
274 | latents = latents * self.scheduler.sigmas[0] | 273 | latents = latents * self.scheduler.sigmas[0] |
275 | elif isinstance(self.scheduler, EulerAScheduler): | 274 | elif isinstance(self.scheduler, EulerAScheduler): |
276 | sigma = self.scheduler.timesteps[0] | 275 | latents = latents * self.scheduler.sigmas[0] |
277 | latents = latents * sigma | ||
278 | 276 | ||
279 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
280 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
@@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
301 | 299 | ||
302 | noise_pred = None | 300 | noise_pred = None |
303 | if isinstance(self.scheduler, EulerAScheduler): | 301 | if isinstance(self.scheduler, EulerAScheduler): |
304 | sigma = t.reshape(1) | 302 | sigma = self.scheduler.sigmas[t].reshape(1) |
305 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 303 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
306 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | ||
307 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 304 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
308 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | 305 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
309 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
310 | else: | 306 | else: |
311 | # predict the noise residual | 307 | # predict the noise residual |
312 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 308 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
@@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
320 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 316 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
321 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample | 317 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample |
322 | elif isinstance(self.scheduler, EulerAScheduler): | 318 | elif isinstance(self.scheduler, EulerAScheduler): |
323 | if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error | 319 | latents = self.scheduler.step(noise_pred, t_index, t_index + 1, |
324 | t_prev = self.scheduler.timesteps[t_index+1] | 320 | latents, **extra_step_kwargs).prev_sample |
325 | latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample | ||
326 | else: | 321 | else: |
327 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 322 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
328 | 323 | ||
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 9fbedaa..1b1c9cf 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -1,7 +1,3 @@ | |||
1 | |||
2 | |||
3 | import math | ||
4 | import warnings | ||
5 | from typing import Optional, Tuple, Union | 1 | from typing import Optional, Tuple, Union |
6 | 2 | ||
7 | import numpy as np | 3 | import numpy as np |
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
157 | beta_end: float = 0.02, | 153 | beta_end: float = 0.02, |
158 | beta_schedule: str = "linear", | 154 | beta_schedule: str = "linear", |
159 | trained_betas: Optional[np.ndarray] = None, | 155 | trained_betas: Optional[np.ndarray] = None, |
160 | clip_sample: bool = True, | ||
161 | set_alpha_to_one: bool = True, | ||
162 | steps_offset: int = 0, | ||
163 | ): | 156 | ): |
164 | if trained_betas is not None: | 157 | if trained_betas is not None: |
165 | self.betas = torch.from_numpy(trained_betas) | 158 | self.betas = torch.from_numpy(trained_betas) |
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
177 | self.alphas = 1.0 - self.betas | 170 | self.alphas = 1.0 - self.betas |
178 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 171 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
179 | 172 | ||
180 | # At every step in ddim, we are looking into the previous alphas_cumprod | ||
181 | # For the final step, there is no previous alphas_cumprod because we are already at 0 | ||
182 | # `set_alpha_to_one` decides whether we set this parameter simply to one or | ||
183 | # whether we use the final alpha of the "non-previous" one. | ||
184 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||
185 | |||
186 | # setable values | 173 | # setable values |
187 | self.num_inference_steps = None | 174 | self.num_inference_steps = None |
188 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | 175 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] |
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
199 | the number of diffusion steps used when generating samples with a pre-trained model. | 186 | the number of diffusion steps used when generating samples with a pre-trained model. |
200 | """ | 187 | """ |
201 | 188 | ||
202 | # offset = self.config.steps_offset | ||
203 | |||
204 | # if "offset" in kwargs: | ||
205 | # warnings.warn( | ||
206 | # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
207 | # " Please pass `steps_offset` to `__init__` instead.", | ||
208 | # DeprecationWarning, | ||
209 | # ) | ||
210 | |||
211 | # offset = kwargs["offset"] | ||
212 | |||
213 | self.num_inference_steps = num_inference_steps | 189 | self.num_inference_steps = num_inference_steps |
214 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 190 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
215 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) | 191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
216 | self.timesteps = self.sigmas | 192 | self.timesteps = np.arange(0, self.num_inference_steps) |
217 | 193 | ||
218 | def add_noise_to_input( | 194 | def add_noise_to_input( |
219 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | 195 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None |
@@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
239 | def step( | 215 | def step( |
240 | self, | 216 | self, |
241 | model_output: torch.FloatTensor, | 217 | model_output: torch.FloatTensor, |
242 | timestep: torch.IntTensor, | 218 | timestep: int, |
243 | timestep_prev: torch.IntTensor, | 219 | timestep_prev: int, |
244 | sample: torch.FloatTensor, | 220 | sample: torch.FloatTensor, |
245 | generator: None, | 221 | generator: None, |
246 | # ,sigma_hat: float, | 222 | # ,sigma_hat: float, |
@@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
266 | returning a tuple, the first element is the sample tensor. | 242 | returning a tuple, the first element is the sample tensor. |
267 | 243 | ||
268 | """ | 244 | """ |
245 | s = self.sigmas[timestep] | ||
246 | s_prev = self.sigmas[timestep_prev] | ||
269 | latents = sample | 247 | latents = sample |
270 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | 248 | |
271 | d = to_d(latents, timestep, model_output) | 249 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) |
272 | dt = sigma_down - timestep | 250 | d = to_d(latents, s, model_output) |
251 | dt = sigma_down - s | ||
273 | latents = latents + d * dt | 252 | latents = latents + d * dt |
274 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | 253 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, |
275 | generator=generator) * sigma_up | 254 | generator=generator) * sigma_up |
255 | |||
276 | return SchedulerOutput(prev_sample=latents) | 256 | return SchedulerOutput(prev_sample=latents) |
277 | 257 | ||
278 | def step_correct( | 258 | def step_correct( |
@@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
311 | 291 | ||
312 | return SchedulerOutput(prev_sample=sample_prev) | 292 | return SchedulerOutput(prev_sample=sample_prev) |
313 | 293 | ||
314 | def add_noise(self, original_samples, noise, timesteps): | 294 | def add_noise( |
315 | raise NotImplementedError() | 295 | self, |
296 | original_samples: torch.FloatTensor, | ||
297 | noise: torch.FloatTensor, | ||
298 | timesteps: torch.IntTensor, | ||
299 | ) -> torch.FloatTensor: | ||
300 | sigmas = self.sigmas.to(original_samples.device) | ||
301 | timesteps = timesteps.to(original_samples.device) | ||
302 | |||
303 | sigma = sigmas[timesteps].flatten() | ||
304 | while len(sigma.shape) < len(original_samples.shape): | ||
305 | sigma = sigma.unsqueeze(-1) | ||
306 | |||
307 | noisy_samples = original_samples + noise * sigma | ||
308 | return noisy_samples | ||