diff options
Diffstat (limited to 'schedulers/scheduling_euler_ancestral_discrete.py')
-rw-r--r-- | schedulers/scheduling_euler_ancestral_discrete.py | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 0000000..3a2de68 --- /dev/null +++ b/schedulers/scheduling_euler_ancestral_discrete.py | |||
@@ -0,0 +1,192 @@ | |||
1 | # Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. | ||
2 | # | ||
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | # you may not use this file except in compliance with the License. | ||
5 | # You may obtain a copy of the License at | ||
6 | # | ||
7 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | # | ||
9 | # Unless required by applicable law or agreed to in writing, software | ||
10 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | # See the License for the specific language governing permissions and | ||
13 | # limitations under the License. | ||
14 | |||
15 | from typing import Optional, Tuple, Union | ||
16 | |||
17 | import numpy as np | ||
18 | import torch | ||
19 | |||
20 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
21 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
22 | |||
23 | |||
24 | class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | ||
25 | """ | ||
26 | Ancestral sampling with Euler method steps. | ||
27 | for discrete beta schedules. Based on the original k-diffusion implementation by | ||
28 | Katherine Crowson: | ||
29 | https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 | ||
30 | |||
31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
33 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
34 | [`~ConfigMixin.from_config`] functions. | ||
35 | |||
36 | Args: | ||
37 | num_train_timesteps (`int`): number of diffusion steps used to train the model. | ||
38 | beta_start (`float`): the starting `beta` value of inference. | ||
39 | beta_end (`float`): the final `beta` value. | ||
40 | beta_schedule (`str`): | ||
41 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from | ||
42 | `linear` or `scaled_linear`. | ||
43 | trained_betas (`np.ndarray`, optional): | ||
44 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. | ||
45 | options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, | ||
46 | `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. | ||
47 | tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. | ||
48 | |||
49 | """ | ||
50 | |||
51 | @register_to_config | ||
52 | def __init__( | ||
53 | self, | ||
54 | num_train_timesteps: int = 1000, | ||
55 | beta_start: float = 0.00085, # sensible defaults | ||
56 | beta_end: float = 0.012, | ||
57 | beta_schedule: str = "linear", | ||
58 | trained_betas: Optional[np.ndarray] = None, | ||
59 | ): | ||
60 | if trained_betas is not None: | ||
61 | self.betas = torch.from_numpy(trained_betas) | ||
62 | elif beta_schedule == "linear": | ||
63 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
64 | elif beta_schedule == "scaled_linear": | ||
65 | # this schedule is very specific to the latent diffusion model. | ||
66 | self.betas = ( | ||
67 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
68 | ) | ||
69 | else: | ||
70 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
71 | |||
72 | self.alphas = 1.0 - self.betas | ||
73 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
74 | |||
75 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
76 | sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) | ||
77 | self.sigmas = torch.from_numpy(sigmas) | ||
78 | |||
79 | self.init_noise_sigma = None | ||
80 | |||
81 | # setable values | ||
82 | self.num_inference_steps = None | ||
83 | timesteps = np.arange(0, num_train_timesteps)[::-1].copy() | ||
84 | self.timesteps = torch.from_numpy(timesteps) | ||
85 | self.derivatives = [] | ||
86 | self.is_scale_input_called = False | ||
87 | |||
88 | def scale_model_input( | ||
89 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] | ||
90 | ) -> torch.FloatTensor: | ||
91 | """ | ||
92 | Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. | ||
93 | |||
94 | Args: | ||
95 | sample (`torch.FloatTensor`): input sample | ||
96 | timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain | ||
97 | |||
98 | Returns: | ||
99 | `torch.FloatTensor`: scaled input sample | ||
100 | """ | ||
101 | sigma = self.sigmas[step_index] | ||
102 | sample = sample / ((sigma**2 + 1) ** 0.5) | ||
103 | return sample | ||
104 | |||
105 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): | ||
106 | """ | ||
107 | Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
108 | |||
109 | Args: | ||
110 | num_inference_steps (`int`): | ||
111 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
112 | """ | ||
113 | self.num_inference_steps = num_inference_steps | ||
114 | self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) | ||
115 | |||
116 | low_idx = np.floor(self.timesteps).astype(int) | ||
117 | high_idx = np.ceil(self.timesteps).astype(int) | ||
118 | frac = np.mod(self.timesteps, 1.0) | ||
119 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
120 | sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] | ||
121 | sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) | ||
122 | self.sigmas = torch.from_numpy(sigmas) | ||
123 | self.timesteps = torch.from_numpy(self.timesteps) | ||
124 | self.init_noise_sigma = self.sigmas[0] | ||
125 | self.derivatives = [] | ||
126 | |||
127 | def step( | ||
128 | self, | ||
129 | model_output: Union[torch.FloatTensor, np.ndarray], | ||
130 | timestep: Union[float, torch.FloatTensor], | ||
131 | step_index: Union[int, torch.IntTensor], | ||
132 | sample: Union[torch.FloatTensor, np.ndarray], | ||
133 | return_dict: bool = True, | ||
134 | ) -> Union[SchedulerOutput, Tuple]: | ||
135 | """ | ||
136 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
137 | process from the learned model outputs (most often the predicted noise). | ||
138 | |||
139 | Args: | ||
140 | model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. | ||
141 | timestep (`int`): current discrete timestep in the diffusion chain. | ||
142 | sample (`torch.FloatTensor` or `np.ndarray`): | ||
143 | current instance of sample being created by diffusion process. | ||
144 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
145 | |||
146 | Returns: | ||
147 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: | ||
148 | [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
149 | returning a tuple, the first element is the sample tensor. | ||
150 | |||
151 | """ | ||
152 | sigma = self.sigmas[step_index] | ||
153 | |||
154 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise | ||
155 | pred_original_sample = sample - sigma * model_output | ||
156 | sigma_from = self.sigmas[step_index] | ||
157 | sigma_to = self.sigmas[step_index + 1] | ||
158 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
159 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
160 | # 2. Convert to an ODE derivative | ||
161 | derivative = (sample - pred_original_sample) / sigma | ||
162 | self.derivatives.append(derivative) | ||
163 | |||
164 | dt = sigma_down - sigma | ||
165 | |||
166 | prev_sample = sample + derivative * dt | ||
167 | |||
168 | prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up | ||
169 | |||
170 | if not return_dict: | ||
171 | return (prev_sample,) | ||
172 | |||
173 | return SchedulerOutput(prev_sample=prev_sample) | ||
174 | |||
175 | def add_noise( | ||
176 | self, | ||
177 | original_samples: torch.FloatTensor, | ||
178 | noise: torch.FloatTensor, | ||
179 | timesteps: torch.IntTensor, | ||
180 | ) -> torch.FloatTensor: | ||
181 | # Make sure sigmas and timesteps have the same device and dtype as original_samples | ||
182 | self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) | ||
183 | self.timesteps = self.timesteps.to(original_samples.device) | ||
184 | sigma = self.sigmas[timesteps].flatten() | ||
185 | while len(sigma.shape) < len(original_samples.shape): | ||
186 | sigma = sigma.unsqueeze(-1) | ||
187 | |||
188 | noisy_samples = original_samples + noise * sigma | ||
189 | return noisy_samples | ||
190 | |||
191 | def __len__(self): | ||
192 | return self.config.num_train_timesteps | ||