summaryrefslogtreecommitdiffstats
path: root/schedulers/scheduling_euler_ancestral_discrete.py
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers/scheduling_euler_ancestral_discrete.py')
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py192
1 files changed, 192 insertions, 0 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
new file mode 100644
index 0000000..3a2de68
--- /dev/null
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -0,0 +1,192 @@
1# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Optional, Tuple, Union
16
17import numpy as np
18import torch
19
20from diffusers.configuration_utils import ConfigMixin, register_to_config
21from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
22
23
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps for discrete beta schedules. Based on the original k-diffusion
    implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Note that `scale_model_input` and `step` take an explicit `step_index`: the position of the current step inside
    `self.sigmas` (0 for the first, noisiest step). The `timestep` argument of those methods is not used to look up
    sigma.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of the schedule.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            Any array-like accepted by `np.asarray` (for example a plain list) works as well.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        if trained_betas is not None:
            # `np.asarray` is a no-op for ndarrays and additionally lets plain sequences through
            # (`torch.from_numpy` alone rejects lists).
            self.betas = torch.from_numpy(np.asarray(trained_betas))
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Full training-resolution sigma schedule, ordered from the largest timestep down to 0,
        # followed by a terminal 0.0 so that `sigmas[i + 1]` exists for the last step.
        sigmas = np.concatenate([self._training_sigmas()[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # Stays `None` until `set_timesteps` has been called.
        self.init_noise_sigma = None

        # setable values
        self.num_inference_steps = None
        timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # Not used by the Euler ancestral update; the attribute is kept so existing code that
        # touches it keeps working.
        self.derivatives = []
        self.is_scale_input_called = False

    def _training_sigmas(self) -> np.ndarray:
        """Return `sqrt((1 - alphas_cumprod) / alphas_cumprod)` as a numpy array, one entry per training timestep."""
        return np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor]
    ) -> torch.FloatTensor:
        """
        Divides the denoising model input by `(sigma**2 + 1) ** 0.5`, where `sigma = self.sigmas[step_index]`.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`):
                the current timestep in the diffusion chain. Accepted for interface compatibility; not used here.
            step_index (`int` or `torch.IntTensor`): index of the current step into `self.sigmas`.

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # Record that the input has been scaled; the flag was initialised in `__init__` but never updated.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Rebuilds `self.timesteps` (fractional, descending from `num_train_timesteps - 1` to 0), `self.sigmas`
        (linearly interpolated at those timesteps, with a trailing 0.0) and `self.init_noise_sigma` (the first
        sigma), and clears `self.derivatives`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                currently ignored; the tensors created here are not moved to it.
        """
        self.num_inference_steps = num_inference_steps
        # Read the value from the registered config, the same way `__len__` does, instead of relying on
        # the config key also being reachable as a plain instance attribute.
        timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

        # Linear interpolation of the training sigmas at the (possibly fractional) timesteps.
        low_idx = np.floor(timesteps).astype(int)
        high_idx = np.ceil(timesteps).astype(int)
        frac = np.mod(timesteps, 1.0)
        sigmas = self._training_sigmas()
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.init_noise_sigma = self.sigmas[0]
        self.derivatives = []

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        step_index: Union[int, torch.IntTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                current timestep in the diffusion chain. Accepted for interface compatibility; not used here.
            step_index (`int` or `torch.IntTensor`):
                index of the current step into `self.sigmas`; `step_index + 1` must also be a valid index.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        """
        sigma_from = self.sigmas[step_index]
        sigma_to = self.sigmas[step_index + 1]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma_from * model_output

        # 2. Split the move from sigma_from to sigma_to into a deterministic part (down to sigma_down)
        #    and fresh noise of scale sigma_up, such that sigma_down**2 + sigma_up**2 == sigma_to**2.
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 3. Convert to an ODE derivative and take one Euler step down to sigma_down.
        #    The derivative is intentionally not stored on `self.derivatives`: nothing in this class reads
        #    it back, and keeping one tensor per step would only hold memory until the next `set_timesteps`.
        derivative = (sample - pred_original_sample) / sigma_from
        dt = sigma_down - sigma_from
        prev_sample = sample + derivative * dt

        # 4. Add the ancestral noise.
        prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Returns `original_samples + noise * sigma`, with `sigma = self.sigmas[timesteps]`.

        `timesteps` is used directly as an index into the current `self.sigmas`. As a side effect, `self.sigmas` is
        converted to the device and dtype of `original_samples`, and `self.timesteps` is moved to its device.

        Args:
            original_samples (`torch.FloatTensor`): samples to add noise to.
            noise (`torch.FloatTensor`): noise to be scaled by sigma and added.
            timesteps (`torch.IntTensor`): indices into `self.sigmas` selecting the noise level(s).

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        self.timesteps = self.timesteps.to(original_samples.device)

        sigma = self.sigmas[timesteps].flatten()
        # Append trailing singleton dimensions so sigma broadcasts against the samples.
        while sigma.ndim < original_samples.ndim:
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Number of training timesteps, as stored in the registered config."""
        return self.config.num_train_timesteps