summaryrefslogtreecommitdiffstats
path: root/schedulers/scheduling_euler_a.py
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers/scheduling_euler_a.py')
-rw-r--r--schedulers/scheduling_euler_a.py286
1 files changed, 0 insertions, 286 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
deleted file mode 100644
index c097a8a..0000000
--- a/schedulers/scheduling_euler_a.py
+++ /dev/null
@@ -1,286 +0,0 @@
1from typing import Optional, Tuple, Union
2
3import numpy as np
4import torch
5
6from diffusers.configuration_utils import ConfigMixin, register_to_config
7from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
8
9
class EulerAScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler ancestral ("Euler a") scheduler working on a discrete sigma schedule.

    The noise levels are derived from a beta schedule as
    `sigma = sqrt((1 - alphas_cumprod) / alphas_cumprod)`. Each call to [`~EulerAScheduler.step`] performs one Euler
    step down to `sigma_down` and then adds fresh Gaussian noise scaled by `sigma_up`. Several helpers
    (`get_ancestral_step`, `to_d`, `append_dims`, ...) are ports from the k_samplers `sampling.py` / `utils.py` code.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of entries in the beta / sigma schedule.
        beta_start (`float`): first value of the beta schedule.
        beta_end (`float`): last value of the beta schedule.
        beta_schedule (`str`): either `"linear"` or `"scaled_linear"`; ignored when `trained_betas` is given.
        trained_betas (`np.ndarray`, optional): explicit betas to use instead of a generated schedule.
        num_inference_steps (`int`, optional): when given, the sigma schedule is resampled to this many steps at
            construction time; otherwise the full training schedule is used until `set_timesteps` is called.
        device (`str` or `torch.device`): device on which the schedule tensors are created.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = 'cuda'
    ):
        # The branches are mutually exclusive: explicit betas must not be overwritten by a generated schedule.
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas).to(device)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
                                        dtype=torch.float32, device=device) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.device = device

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settable values
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        # Defined here so the flag exists even before `set_timesteps` / `add_noise` is called.
        self.is_scale_input_called = False

        # "DS" = discrete schedule: one sigma per training timestep, in increasing-timestep order.
        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        After this call `self.sigmas` holds `num_inference_steps + 1` values (the last one is 0) and
        `self.timesteps` holds the first `num_inference_steps` of them, i.e. the "timesteps" handed to `step` are
        sigma values. The one-time input scaling flag is reset.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                accepted for interface compatibility only; the schedule is always created on the device given to
                `__init__`.
        """
        self.num_inference_steps = num_inference_steps
        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
        self.timesteps = self.sigmas[:-1]
        self.is_scale_input_called = False

    def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Only the first call after `set_timesteps` scales the sample (it is multiplied by the sigma belonging to
        `timestep`); every later call, and every call after `add_noise`, returns the sample unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep; must be one of the values in `self.timesteps`
        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        if self.is_scale_input_called:
            return sample
        step_index = (self.timesteps == timestep).nonzero().item()
        self.is_scale_input_called = True
        return sample * self.sigmas[step_index]

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: torch.Generator = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep with one ancestral Euler step.

        Args:
            model_output (`torch.FloatTensor`): model result for the current sample. It is passed to `to_d` as the
                `denoised` argument, so it is presumably expected to be the denoised prediction rather than raw
                noise - confirm against the calling pipeline.
            timestep (`float` or `torch.FloatTensor`): current timestep; must be one of the values in
                `self.timesteps`.
            sample (`torch.FloatTensor`): current sample of the diffusion chain.
            generator (`torch.Generator`, optional): random number generator for the added noise.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] or `tuple`: [`SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()

        sigma = self.sigmas[step_index]
        # `self.sigmas` has a trailing 0, so index + 1 is valid for every entry of `self.timesteps`.
        sigma_next = self.sigmas[step_index + 1]

        sigma_down, sigma_up = self.get_ancestral_step(sigma, sigma_next)
        derivative = self.to_d(sample, sigma, model_output)

        # Deterministic Euler step to sigma_down, then re-inject noise of magnitude sigma_up.
        prev_sample = sample + derivative * (sigma_down - sigma)
        noise = torch.randn(sample.shape, layout=sample.layout, device=sample.device, dtype=sample.dtype,
                            generator=generator)
        prev_sample = prev_sample + noise * sigma_up

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.FloatTensor,
        sample_prev: torch.FloatTensor,
        derivative: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample by averaging the supplied derivative with one recomputed at `sigma_prev`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            sigma_hat (`float`): noise level at which `sample_hat` and `derivative` were evaluated.
            sigma_prev (`float`): noise level of `sample_prev`; used as a divisor, so it must be non-zero.
            sample_hat (`torch.FloatTensor`): sample the step started from.
            sample_prev (`torch.FloatTensor`): sample predicted by the uncorrected step.
            derivative (`torch.FloatTensor`): derivative used by the uncorrected step.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] holding the corrected sample if `return_dict` is True, otherwise the tuple
            `(corrected_sample, derivative)` where `derivative` is the unchanged input derivative.
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return SchedulerOutput(prev_sample=sample_prev)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add `noise`, scaled by the sigma of each given timestep, to `original_samples`.

        Also marks the input as already scaled, so a following `scale_model_input` call leaves the sample unchanged.
        `set_timesteps` is expected to have been called first, because `self.timesteps` must be a tensor here.

        Args:
            original_samples (`torch.FloatTensor`): clean samples.
            noise (`torch.FloatTensor`): noise to add; must broadcast against `original_samples`.
            timesteps (`torch.FloatTensor`): one timestep per sample, each a value from `self.timesteps`.

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        target_device = original_samples.device
        sigmas = self.sigmas.to(target_device)
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Add trailing singleton axes so sigma broadcasts over the sample dimensions.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        self.is_scale_input_called = True
        return noisy_samples

    # from k_samplers sampling.py

    def get_ancestral_step(self, sigma_from, sigma_to):
        """Calculates the noise level (sigma_down) to step down to and the amount
        of noise to add (sigma_up) when doing an ancestral sampling step.

        Returns the pair `(sigma_down, sigma_up)`, which satisfies `sigma_down**2 + sigma_up**2 == sigma_to**2`.
        """
        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    def t_to_sigma(self, t, sigmas):
        """Linearly interpolates `sigmas` at the (possibly fractional) index positions `t`."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]

    def append_zero(self, x):
        """Returns `x` with a single zero appended (same dtype and device as `x`)."""
        return torch.cat([x, x.new_zeros([1])])

    def get_sigmas(self, sigmas, n=None):
        """Builds the sampling sigma schedule, always terminated by a 0.

        With `n=None` the full schedule is used in reversed order. Otherwise `n` index positions are spaced
        linearly from the last index of `sigmas` down to 0 and the sigmas are interpolated at those positions.
        """
        if n is None:
            return self.append_zero(sigmas.flip(0))
        t_max = len(sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.device)
        return self.append_zero(self.t_to_sigma(t, sigmas))

    # from k_samplers utils.py
    def append_dims(self, x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
        return x[(...,) + (None,) * dims_to_append]

    # from k_samplers sampling.py
    def to_d(self, x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / self.append_dims(sigma, x.ndim)

    def get_scalings(self, sigma):
        """Returns the output and input scalings `(c_out, c_in)` for noise level `sigma` (sigma_data fixed to 1)."""
        sigma_data = 1.
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
        return c_out, c_in

    # DiscreteSchedule DS
    def DSsigma_to_t(self, sigma, quantize=None):
        """Maps sigma values to positions in the discrete training schedule `self.DSsigmas`.

        Args:
            sigma: tensor of noise levels.
            quantize: when truthy, return the index of the nearest schedule entry; otherwise (including the
                default `None`) interpolate between the two nearest entries and return a fractional position.

        Returns:
            Tensor with the same shape as `sigma`.
        """
        quantize = False if quantize is None else quantize
        dists = torch.abs(sigma - self.DSsigmas[:, None])
        if quantize:
            return torch.argmin(dists, dim=0).view(sigma.shape)
        # Indices of the two closest schedule entries, ordered so that low_idx < high_idx.
        low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
        low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
        w = (low - sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def prepare_input(self, latent_in, t, batch_size):
        """Builds the scalings and the timestep input for a batch at noise level `t`.

        The single sigma `t` is repeated `2 * batch_size` times (presumably one copy each for the conditional
        and unconditional branch of guidance - confirm against the caller).

        Returns:
            `(c_out, c_in, sigma_in)`: the two scalings expanded to `latent_in.ndim` dimensions and the repeated
            sigma converted to schedule positions via `DSsigma_to_t`.
        """
        sigma = t.reshape(1)  # A# potential bug: doesn't work on samples > 1

        sigma_in = torch.cat([sigma] * 2 * batch_size)
        c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]

        sigma_in = self.DSsigma_to_t(sigma_in)

        return c_out, c_in, sigma_in