summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py323
1 files changed, 323 insertions, 0 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
new file mode 100644
index 0000000..57a56de
--- /dev/null
+++ b/schedulers/scheduling_euler_a.py
@@ -0,0 +1,323 @@
1
2
3import math
4import warnings
5from typing import Optional, Tuple, Union
6
7import numpy as np
8import torch
9
10from diffusers.configuration_utils import ConfigMixin, register_to_config
11from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
12
13
'''
Helper functions (these still need cleaning up):
    append_zero(), t_to_sigma(), get_sigmas(), append_dims(), CFGDenoiserForward(),
    to_d(), get_scalings(), DSsigma_to_t(), DiscreteEpsDDPMDenoiserForward(),
    get_ancestral_step()
'''
27
28
def append_zero(x):
    """Return a new tensor equal to the 1-D tensor ``x`` followed by a single zero."""
    # new_zeros inherits dtype and device from ``x``, so the two pieces can be concatenated.
    trailing_zero = x.new_zeros([1])
    return torch.cat([x, trailing_zero])
31
32
def t_to_sigma(t, sigmas):
    """Linearly interpolate ``sigmas`` at the (possibly fractional) index positions ``t``.

    Each position is split into its floor and ceil indices; the fractional part is the
    weight given to the upper entry.
    """
    position = t.float()
    lower = position.floor().long()
    upper = position.ceil().long()
    weight = position.frac()
    return (1 - weight) * sigmas[lower] + weight * sigmas[upper]
37
38
def get_sigmas(sigmas, n=None):
    """Build a descending sigma schedule terminated by a zero.

    Args:
        sigmas: 1-D tensor of sigma values indexed by training timestep.
        n: number of sampling steps. When ``None`` the full table is used, reversed.

    Returns:
        A tensor with one more entry than the number of steps; the last entry is 0.
    """
    if n is not None:
        # Sample ``n`` evenly spaced (fractional) indices from the last index down to 0
        # and interpolate the sigma table at those positions.
        last_index = len(sigmas) - 1
        steps = torch.linspace(last_index, 0, n, device=sigmas.device)
        return append_zero(t_to_sigma(steps, sigmas))
    return append_zero(sigmas.flip(0))
46
47# from k_samplers utils.py
48
49
def append_dims(x, target_dims):
    """Add trailing singleton dimensions to ``x`` until it has ``target_dims`` dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Indexing with None inserts a size-1 axis; Ellipsis keeps the existing axes in front.
        index = (Ellipsis,) + missing * (None,)
        return x[index]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
56
57
def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None):
    """Evaluate the denoiser once and return its output unchanged.

    This simply forwards to ``DiscreteEpsDDPMDenoiserForward`` with ``cond_in`` passed as
    the ``cond`` keyword.

    NOTE: ``cond_scale`` is accepted but never used here. The guidance mix
    ``uncond + (cond - uncond) * cond_scale`` is NOT applied by this function, and the
    inputs are not duplicated/concatenated here either.
    """
    return DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in)
67
68# from k_samplers sampling.py
69
70
def to_d(x, sigma, denoised):
    """Convert a denoiser output into a Karras ODE derivative: ``(x - denoised) / sigma``."""
    residual = x - denoised
    # Move sigma next to the denoised tensor and give it trailing singleton axes so the
    # division broadcasts over the remaining dimensions of ``x``.
    sigma_on_device = sigma.to(denoised.device)
    divisor = append_dims(sigma_on_device, x.ndim)
    return residual / divisor
74
75
def get_scalings(sigma):
    """Return the ``(c_out, c_in)`` scaling pair for noise level ``sigma``.

    ``c_out`` is ``-sigma`` and ``c_in`` is ``1 / sqrt(sigma**2 + sigma_data**2)`` with
    ``sigma_data`` fixed to 1.
    """
    sigma_data = 1.
    total_variance = sigma ** 2 + sigma_data ** 2
    return -sigma, 1 / total_variance ** 0.5
81
82# DiscreteSchedule DS
83
84
def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
    """Map sigma values to (possibly fractional) indices into the table ``DSsigmas``.

    Args:
        sigma: tensor of noise levels to look up.
        quantize: when truthy, return the index of the nearest table entry (an integer
            tensor). When ``None`` or falsy (the default), interpolate linearly between
            the two nearest table entries and return a float index.
        DSsigmas: 1-D tensor of sigma values indexed by timestep. Required.

    Returns:
        A tensor with the same shape as ``sigma`` holding the timestep indices.

    Raises:
        ValueError: if ``DSsigmas`` is not provided.
    """
    if DSsigmas is None:
        raise ValueError("DSsigmas must be provided to map sigma values to timesteps")
    # Previously this flag was unconditionally overwritten with False, which made the
    # quantized branch unreachable. None keeps the old default (interpolate).
    quantize = bool(quantize)

    # Distance of every queried sigma to every table entry: shape (len(DSsigmas), ...).
    dists = torch.abs(sigma - DSsigmas[:, None])
    if quantize:
        return torch.argmin(dists, dim=0).view(sigma.shape)

    # Take the two closest table entries and order them by index so that low_idx < high_idx.
    nearest_two = torch.topk(dists, dim=0, k=2, largest=False).indices
    low_idx, high_idx = torch.sort(nearest_two, dim=0)[0]
    low, high = DSsigmas[low_idx], DSsigmas[high_idx]
    # Interpolation weight of the upper entry; clamped because the two nearest entries
    # do not necessarily bracket ``sigma``.
    w = ((low - sigma) / (low - high)).clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)
97
98
def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs):
    """Turn an epsilon-predicting ``Unet`` into a denoiser for noise level ``sigma``.

    The input is scaled by ``c_in``, the model is called with the timestep index that
    corresponds to ``sigma`` and with ``kwargs['cond']`` as ``encoder_hidden_states``,
    and the result is ``input + eps * c_out`` where ``eps`` is the ``.sample`` attribute
    of the model output.
    """
    device = Unet.device
    sigma = sigma.to(device)
    DSsigmas = DSsigmas.to(device)

    # Give both scalings trailing singleton axes so they broadcast against ``input``.
    scalings = get_scalings(sigma)
    c_out = append_dims(scalings[0], input.ndim)
    c_in = append_dims(scalings[1], input.ndim)

    scaled_input = input * c_in
    timesteps = DSsigma_to_t(sigma, DSsigmas=DSsigmas)
    eps = Unet(scaled_input, timesteps, encoder_hidden_states=kwargs['cond']).sample
    return input + eps * c_out
108
109
110# from k_samplers sampling.py
def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into its deterministic and stochastic parts.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to, and the amount of
        fresh noise to add afterwards.
    """
    from_sq = sigma_from ** 2
    to_sq = sigma_to ** 2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
117
118
119'''
120Euler Ancestral Scheduler
121'''
122
123
class EulerAScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler ancestral sampling scheduler that works on sigma (noise level) values.

    `__init__` builds the betas / alphas / cumulative alphas from the configured beta schedule.
    `set_timesteps` converts them to sigmas via `sqrt((1 - alphas_cumprod) / alphas_cumprod)` and stores the
    inference schedule in both `self.sigmas` and `self.timesteps`. `step` performs one Euler step down to
    `sigma_down` and then adds fresh noise scaled by `sigma_up`.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of entries in the beta schedule.
        beta_start (`float`): first beta of the schedule.
        beta_end (`float`): last beta of the schedule.
        beta_schedule (`str`): one of `linear`, `scaled_linear` or `squaredcos_cap_v2`. Ignored when
            `trained_betas` is given.
        trained_betas (`np.ndarray`, optional): explicit array of betas that bypasses `beta_schedule`.
        clip_sample (`bool`): not read anywhere in this class; kept in the signature.
        set_alpha_to_one (`bool`): if True `final_alpha_cumprod` is 1.0, otherwise it is `alphas_cumprod[0]`.
        steps_offset (`int`): not read anywhere in this class; kept in the signature.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
    ):
        # `trained_betas` takes precedence. This must be a single if/elif chain, otherwise the
        # schedule branches below overwrite the user-supplied betas.
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = self._betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # `set_alpha_to_one` decides whether the "previous" cumulative alpha of the final step is
        # simply one or the first entry of `alphas_cumprod`.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1]

    @staticmethod
    def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
        """Build the squared-cosine beta schedule; each beta is capped at `max_beta`."""

        def alpha_bar(time_step):
            return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = [
            min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
            for i in range(num_diffusion_timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
        """
        Sets the sigma schedule used for the diffusion chain. Supporting function to be run before inference.

        After this call `self.sigmas` and `self.timesteps` are the same tensor of descending sigma values
        terminated by a zero, and `self.DSsigmas` holds one sigma per training timestep.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                device the sigma schedule is moved to.
        """
        self.num_inference_steps = num_inference_steps
        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
        self.timesteps = self.sigmas

    def add_noise_to_input(
        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.FloatTensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.

        NOTE(review): this reads `s_min`, `s_max`, `s_churn` and `s_noise` from `self.config`, but none of them is
        a parameter of `__init__` in this class -- confirm where they are expected to come from before relying on
        this method.

        Args:
            sample (`torch.FloatTensor`): current sample.
            sigma (`float`): current noise level.
            generator (`torch.Generator`, optional): random number generator used to draw the noise.

        Returns:
            `(sample_hat, sigma_hat)`: the noised sample and the increased noise level.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: torch.FloatTensor,
        timestep_prev: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Perform one Euler ancestral step from noise level `timestep` to `timestep_prev`.

        Args:
            model_output (`torch.FloatTensor`): used as the denoised estimate when computing the derivative
                `(sample - model_output) / timestep`.
            timestep (`torch.FloatTensor`): current sigma. Presumably an entry of `self.timesteps`, which holds
                sigma values after `set_timesteps` -- verify against the caller.
            timestep_prev (`torch.FloatTensor`): sigma to step to.
            sample (`torch.FloatTensor`): current sample.
            generator (`torch.Generator`, optional): random number generator for the added noise.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] if `return_dict` is True, otherwise a `tuple` whose first element is the
            updated sample.

        """
        sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
        derivative = to_d(sample, timestep, model_output)

        # Euler method: deterministic move down to sigma_down ...
        dt = sigma_down - timestep
        prev_sample = sample + derivative * dt
        # ... followed by re-injection of fresh noise with standard deviation sigma_up.
        noise = torch.randn(sample.shape, layout=sample.layout, device=sample.device, generator=generator)
        prev_sample = prev_sample + noise * sigma_up

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.FloatTensor,
        sample_prev: torch.FloatTensor,
        derivative: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample by averaging the supplied derivative with one recomputed at `sigma_prev`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            sigma_hat (`float`): noise level of `sample_hat`.
            sigma_prev (`float`): noise level of `sample_prev`.
            sample_hat (`torch.FloatTensor`): sample the corrected step starts from.
            sample_prev (`torch.FloatTensor`): sample predicted by the uncorrected step.
            derivative (`torch.FloatTensor`): derivative used by the uncorrected step.
            generator (`torch.Generator`, optional): accepted but not used by this method.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] with the corrected sample if `return_dict` is True, otherwise the tuple
            `(sample_prev, derivative)`.

        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return SchedulerOutput(prev_sample=sample_prev)

    def add_noise(self, original_samples, noise, timesteps):
        """Not supported by this scheduler; always raises `NotImplementedError`."""
        raise NotImplementedError()