summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py286
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py192
2 files changed, 192 insertions, 286 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
deleted file mode 100644
index c097a8a..0000000
--- a/schedulers/scheduling_euler_a.py
+++ /dev/null
@@ -1,286 +0,0 @@
1from typing import Optional, Tuple, Union
2
3import numpy as np
4import torch
5
6from diffusers.configuration_utils import ConfigMixin, register_to_config
7from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
8
9
10class EulerAScheduler(SchedulerMixin, ConfigMixin):
11 """
12 Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
13 the VE column of Table 1 from [1] for reference.
14
15 [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
16 https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
17 differential equations." https://arxiv.org/abs/2011.13456
18
19 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
20 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
21 [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
22 [`~ConfigMixin.from_config`] functions.
23
24 For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
25 Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
26 optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
27
28 Args:
29 sigma_min (`float`): minimum noise magnitude
30 sigma_max (`float`): maximum noise magnitude
31 s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
32 A reasonable range is [1.000, 1.011].
33 s_churn (`float`): the parameter controlling the overall amount of stochasticity.
34 A reasonable range is [0, 100].
35 s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
36 A reasonable range is [0, 10].
37 s_max (`float`): the end value of the sigma range where we add noise.
38 A reasonable range is [0.2, 80].
39
40 """
41
42 @register_to_config
43 def __init__(
44 self,
45 num_train_timesteps: int = 1000,
46 beta_start: float = 0.0001,
47 beta_end: float = 0.02,
48 beta_schedule: str = "linear",
49 trained_betas: Optional[np.ndarray] = None,
50 num_inference_steps=None,
51 device='cuda'
52 ):
53 if trained_betas is not None:
54 self.betas = torch.from_numpy(trained_betas).to(device)
55 if beta_schedule == "linear":
56 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
57 elif beta_schedule == "scaled_linear":
58 # this schedule is very specific to the latent diffusion model.
59 self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
60 dtype=torch.float32, device=device) ** 2
61 else:
62 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
63
64 self.device = device
65
66 self.alphas = 1.0 - self.betas
67 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
68
69 # standard deviation of the initial noise distribution
70 self.init_noise_sigma = 1.0
71
72 # setable values
73 self.num_inference_steps = num_inference_steps
74 self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
75 # get sigmas
76 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
77 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
78
79 # A# take number of steps as input
80 # A# store 1) number of steps 2) timesteps 3) schedule
81
82 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
83 """
84 Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
85
86 Args:
87 num_inference_steps (`int`):
88 the number of diffusion steps used when generating samples with a pre-trained model.
89 """
90
91 self.num_inference_steps = num_inference_steps
92 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
93 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
94 self.timesteps = self.sigmas[:-1]
95 self.is_scale_input_called = False
96
97 def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
98 """
99 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
100 current timestep.
101 Args:
102 sample (`torch.FloatTensor`): input sample
103 timestep (`int`, optional): current timestep
104 Returns:
105 `torch.FloatTensor`: scaled input sample
106 """
107 if isinstance(timestep, torch.Tensor):
108 timestep = timestep.to(self.timesteps.device)
109 if self.is_scale_input_called:
110 return sample
111 step_index = (self.timesteps == timestep).nonzero().item()
112 sigma = self.sigmas[step_index]
113 sample = sample * sigma
114 self.is_scale_input_called = True
115 return sample
116
117 def step(
118 self,
119 model_output: torch.FloatTensor,
120 timestep: Union[float, torch.FloatTensor],
121 sample: torch.FloatTensor,
122 generator: torch.Generator = None,
123 return_dict: bool = True,
124 ) -> Union[SchedulerOutput, Tuple]:
125 """
126 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
127 process from the learned model outputs (most often the predicted noise).
128
129 Args:
130 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
131 sigma_hat (`float`): TODO
132 sigma_prev (`float`): TODO
133 sample_hat (`torch.FloatTensor`): TODO
134 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
135
136 EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check).
137 Returns:
138 [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`:
139 [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When
140 returning a tuple, the first element is the sample tensor.
141
142 """
143 if isinstance(timestep, torch.Tensor):
144 timestep = timestep.to(self.timesteps.device)
145 step_index = (self.timesteps == timestep).nonzero().item()
146 step_prev_index = step_index + 1
147
148 s = self.sigmas[step_index]
149 s_prev = self.sigmas[step_prev_index]
150 latents = sample
151
152 sigma_down, sigma_up = self.get_ancestral_step(s, s_prev)
153 d = self.to_d(latents, s, model_output)
154 dt = sigma_down - s
155 latents = latents + d * dt
156 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
157 generator=generator) * sigma_up
158
159 return SchedulerOutput(prev_sample=latents)
160
161 def step_correct(
162 self,
163 model_output: torch.FloatTensor,
164 sigma_hat: float,
165 sigma_prev: float,
166 sample_hat: torch.FloatTensor,
167 sample_prev: torch.FloatTensor,
168 derivative: torch.FloatTensor,
169 return_dict: bool = True,
170 ) -> Union[SchedulerOutput, Tuple]:
171 """
172 Correct the predicted sample based on the output model_output of the network. TODO complete description
173
174 Args:
175 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
176 sigma_hat (`float`): TODO
177 sigma_prev (`float`): TODO
178 sample_hat (`torch.FloatTensor`): TODO
179 sample_prev (`torch.FloatTensor`): TODO
180 derivative (`torch.FloatTensor`): TODO
181 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
182
183 Returns:
184 prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
185
186 """
187 pred_original_sample = sample_prev + sigma_prev * model_output
188 derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
189 sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
190
191 if not return_dict:
192 return (sample_prev, derivative)
193
194 return SchedulerOutput(prev_sample=sample_prev)
195
196 def add_noise(
197 self,
198 original_samples: torch.FloatTensor,
199 noise: torch.FloatTensor,
200 timesteps: torch.FloatTensor,
201 ) -> torch.FloatTensor:
202 sigmas = self.sigmas.to(original_samples.device)
203 schedule_timesteps = self.timesteps.to(original_samples.device)
204 timesteps = timesteps.to(original_samples.device)
205 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
206
207 sigma = sigmas[step_indices].flatten()
208 while len(sigma.shape) < len(original_samples.shape):
209 sigma = sigma.unsqueeze(-1)
210
211 noisy_samples = original_samples + noise * sigma
212 self.is_scale_input_called = True
213 return noisy_samples
214
215 # from k_samplers sampling.py
216
217 def get_ancestral_step(self, sigma_from, sigma_to):
218 """Calculates the noise level (sigma_down) to step down to and the amount
219 of noise to add (sigma_up) when doing an ancestral sampling step."""
220 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
221 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
222 return sigma_down, sigma_up
223
224 def t_to_sigma(self, t, sigmas):
225 t = t.float()
226 low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
227 return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
228
229 def append_zero(self, x):
230 return torch.cat([x, x.new_zeros([1])])
231
232 def get_sigmas(self, sigmas, n=None):
233 if n is None:
234 return self.append_zero(sigmas.flip(0))
235 t_max = len(sigmas) - 1 # = 999
236 device = self.device
237 t = torch.linspace(t_max, 0, n, device=device)
238 # t = torch.linspace(t_max, 0, n, device=sigmas.device)
239 return self.append_zero(self.t_to_sigma(t, sigmas))
240
241 # from k_samplers utils.py
242 def append_dims(self, x, target_dims):
243 """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
244 dims_to_append = target_dims - x.ndim
245 if dims_to_append < 0:
246 raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
247 return x[(...,) + (None,) * dims_to_append]
248
249 # from k_samplers sampling.py
250 def to_d(self, x, sigma, denoised):
251 """Converts a denoiser output to a Karras ODE derivative."""
252 return (x - denoised) / self.append_dims(sigma, x.ndim)
253
254 def get_scalings(self, sigma):
255 sigma_data = 1.
256 c_out = -sigma
257 c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
258 return c_out, c_in
259
260 # DiscreteSchedule DS
261 def DSsigma_to_t(self, sigma, quantize=None):
262 # quantize = self.quantize if quantize is None else quantize
263 quantize = False
264 dists = torch.abs(sigma - self.DSsigmas[:, None])
265 if quantize:
266 return torch.argmin(dists, dim=0).view(sigma.shape)
267 low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
268 low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
269 w = (low - sigma) / (low - high)
270 w = w.clamp(0, 1)
271 t = (1 - w) * low_idx + w * high_idx
272 return t.view(sigma.shape)
273
274 def prepare_input(self, latent_in, t, batch_size):
275 sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1
276
277 sigma_in = torch.cat([sigma] * 2 * batch_size)
278 # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
279 # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
280 c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
281
282 sigma_in = self.DSsigma_to_t(sigma_in)
283 # s_in = latent_in.new_ones([latent_in.shape[0]])
284 # sigma_in = sigma_in * s_in
285
286 return c_out, c_in, sigma_in
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
new file mode 100644
index 0000000..3a2de68
--- /dev/null
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -0,0 +1,192 @@
1# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Optional, Tuple, Union
16
17import numpy as np
18import torch
19
20from diffusers.configuration_utils import ConfigMixin, register_to_config
21from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
22
23
24class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
25 """
26 Ancestral sampling with Euler method steps.
27 for discrete beta schedules. Based on the original k-diffusion implementation by
28 Katherine Crowson:
29 https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
30
31 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
32 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
33 [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
34 [`~ConfigMixin.from_config`] functions.
35
36 Args:
37 num_train_timesteps (`int`): number of diffusion steps used to train the model.
38 beta_start (`float`): the starting `beta` value of inference.
39 beta_end (`float`): the final `beta` value.
40 beta_schedule (`str`):
41 the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
42 `linear` or `scaled_linear`.
43 trained_betas (`np.ndarray`, optional):
44 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
45 options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
46 `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
47 tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
48
49 """
50
51 @register_to_config
52 def __init__(
53 self,
54 num_train_timesteps: int = 1000,
55 beta_start: float = 0.00085, # sensible defaults
56 beta_end: float = 0.012,
57 beta_schedule: str = "linear",
58 trained_betas: Optional[np.ndarray] = None,
59 ):
60 if trained_betas is not None:
61 self.betas = torch.from_numpy(trained_betas)
62 elif beta_schedule == "linear":
63 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
64 elif beta_schedule == "scaled_linear":
65 # this schedule is very specific to the latent diffusion model.
66 self.betas = (
67 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
68 )
69 else:
70 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
71
72 self.alphas = 1.0 - self.betas
73 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
74
75 sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
76 sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
77 self.sigmas = torch.from_numpy(sigmas)
78
79 self.init_noise_sigma = None
80
81 # setable values
82 self.num_inference_steps = None
83 timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
84 self.timesteps = torch.from_numpy(timesteps)
85 self.derivatives = []
86 self.is_scale_input_called = False
87
88 def scale_model_input(
89 self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor]
90 ) -> torch.FloatTensor:
91 """
92 Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
93
94 Args:
95 sample (`torch.FloatTensor`): input sample
96 timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
97
98 Returns:
99 `torch.FloatTensor`: scaled input sample
100 """
101 sigma = self.sigmas[step_index]
102 sample = sample / ((sigma**2 + 1) ** 0.5)
103 return sample
104
105 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
106 """
107 Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
108
109 Args:
110 num_inference_steps (`int`):
111 the number of diffusion steps used when generating samples with a pre-trained model.
112 """
113 self.num_inference_steps = num_inference_steps
114 self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
115
116 low_idx = np.floor(self.timesteps).astype(int)
117 high_idx = np.ceil(self.timesteps).astype(int)
118 frac = np.mod(self.timesteps, 1.0)
119 sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
120 sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
121 sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
122 self.sigmas = torch.from_numpy(sigmas)
123 self.timesteps = torch.from_numpy(self.timesteps)
124 self.init_noise_sigma = self.sigmas[0]
125 self.derivatives = []
126
127 def step(
128 self,
129 model_output: Union[torch.FloatTensor, np.ndarray],
130 timestep: Union[float, torch.FloatTensor],
131 step_index: Union[int, torch.IntTensor],
132 sample: Union[torch.FloatTensor, np.ndarray],
133 return_dict: bool = True,
134 ) -> Union[SchedulerOutput, Tuple]:
135 """
136 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
137 process from the learned model outputs (most often the predicted noise).
138
139 Args:
140 model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
141 timestep (`int`): current discrete timestep in the diffusion chain.
142 sample (`torch.FloatTensor` or `np.ndarray`):
143 current instance of sample being created by diffusion process.
144 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
145
146 Returns:
147 [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
148 [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
149 returning a tuple, the first element is the sample tensor.
150
151 """
152 sigma = self.sigmas[step_index]
153
154 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
155 pred_original_sample = sample - sigma * model_output
156 sigma_from = self.sigmas[step_index]
157 sigma_to = self.sigmas[step_index + 1]
158 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
159 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
160 # 2. Convert to an ODE derivative
161 derivative = (sample - pred_original_sample) / sigma
162 self.derivatives.append(derivative)
163
164 dt = sigma_down - sigma
165
166 prev_sample = sample + derivative * dt
167
168 prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up
169
170 if not return_dict:
171 return (prev_sample,)
172
173 return SchedulerOutput(prev_sample=prev_sample)
174
175 def add_noise(
176 self,
177 original_samples: torch.FloatTensor,
178 noise: torch.FloatTensor,
179 timesteps: torch.IntTensor,
180 ) -> torch.FloatTensor:
181 # Make sure sigmas and timesteps have the same device and dtype as original_samples
182 self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
183 self.timesteps = self.timesteps.to(original_samples.device)
184 sigma = self.sigmas[timesteps].flatten()
185 while len(sigma.shape) < len(original_samples.shape):
186 sigma = sigma.unsqueeze(-1)
187
188 noisy_samples = original_samples + noise * sigma
189 return noisy_samples
190
191 def __len__(self):
192 return self.config.num_train_timesteps