summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py162
1 files changed, 112 insertions, 50 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
index 828e0dd..cef50fe 100644
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -1,4 +1,4 @@
1# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. 1# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
2# 2#
3# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License. 4# you may not use this file except in compliance with the License.
@@ -12,20 +12,42 @@
12# See the License for the specific language governing permissions and 12# See the License for the specific language governing permissions and
13# limitations under the License. 13# limitations under the License.
14 14
15from dataclasses import dataclass
15from typing import Optional, Tuple, Union 16from typing import Optional, Tuple, Union
16 17
17import numpy as np 18import numpy as np
18import torch 19import torch
19 20
20from diffusers.configuration_utils import ConfigMixin, register_to_config 21from diffusers.configuration_utils import ConfigMixin, register_to_config
21from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput 22from diffusers.utils import BaseOutput, deprecate, logging
23from diffusers.schedulers.scheduling_utils import SchedulerMixin
24
25
26logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
28
29@dataclass
30# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
31class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
32 """
33 Output class for the scheduler's step function output.
34
35 Args:
36 prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37 Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
38 denoising loop.
39 pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40 The predicted denoised sample (x_{0}) based on the model output from the current timestep.
41 `pred_original_sample` can be used to preview progress or for guidance.
42 """
43
44 prev_sample: torch.FloatTensor
45 pred_original_sample: Optional[torch.FloatTensor] = None
22 46
23 47
24class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): 48class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
25 """ 49 """
26 Ancestral sampling with Euler method steps. 50 Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
27 for discrete beta schedules. Based on the original k-diffusion implementation by
28 Katherine Crowson:
29 https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 51 https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
30 52
31 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 53 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
@@ -42,9 +64,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
42 `linear` or `scaled_linear`. 64 `linear` or `scaled_linear`.
43 trained_betas (`np.ndarray`, optional): 65 trained_betas (`np.ndarray`, optional):
44 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 66 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
45 options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
46 `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
47 tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
48 67
49 """ 68 """
50 69
@@ -52,8 +71,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
52 def __init__( 71 def __init__(
53 self, 72 self,
54 num_train_timesteps: int = 1000, 73 num_train_timesteps: int = 1000,
55 beta_start: float = 0.00085, # sensible defaults 74 beta_start: float = 0.0001,
56 beta_end: float = 0.012, 75 beta_end: float = 0.02,
57 beta_schedule: str = "linear", 76 beta_schedule: str = "linear",
58 trained_betas: Optional[np.ndarray] = None, 77 trained_betas: Optional[np.ndarray] = None,
59 ): 78 ):
@@ -76,20 +95,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
76 sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) 95 sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
77 self.sigmas = torch.from_numpy(sigmas) 96 self.sigmas = torch.from_numpy(sigmas)
78 97
79 self.init_noise_sigma = None 98 # standard deviation of the initial noise distribution
99 self.init_noise_sigma = self.sigmas.max()
80 100
81 # setable values 101 # setable values
82 self.num_inference_steps = None 102 self.num_inference_steps = None
83 timesteps = np.arange(0, num_train_timesteps)[::-1].copy() 103 timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
84 self.timesteps = torch.from_numpy(timesteps) 104 self.timesteps = torch.from_numpy(timesteps)
85 self.derivatives = []
86 self.is_scale_input_called = False 105 self.is_scale_input_called = False
87 106
88 def scale_model_input( 107 def scale_model_input(
89 self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] 108 self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
90 ) -> torch.FloatTensor: 109 ) -> torch.FloatTensor:
91 """ 110 """
92 Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. 111 Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
93 112
94 Args: 113 Args:
95 sample (`torch.FloatTensor`): input sample 114 sample (`torch.FloatTensor`): input sample
@@ -98,8 +117,12 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
98 Returns: 117 Returns:
99 `torch.FloatTensor`: scaled input sample 118 `torch.FloatTensor`: scaled input sample
100 """ 119 """
120 if isinstance(timestep, torch.Tensor):
121 timestep = timestep.to(self.timesteps.device)
122 step_index = (self.timesteps == timestep).nonzero().item()
101 sigma = self.sigmas[step_index] 123 sigma = self.sigmas[step_index]
102 sample = sample / ((sigma**2 + 1) ** 0.5) 124 sample = sample / ((sigma**2 + 1) ** 0.5)
125 self.is_scale_input_called = True
103 return sample 126 return sample
104 127
105 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 128 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -109,86 +132,125 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
109 Args: 132 Args:
110 num_inference_steps (`int`): 133 num_inference_steps (`int`):
111 the number of diffusion steps used when generating samples with a pre-trained model. 134 the number of diffusion steps used when generating samples with a pre-trained model.
135 device (`str` or `torch.device`, optional):
136 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
112 """ 137 """
113 self.num_inference_steps = num_inference_steps 138 self.num_inference_steps = num_inference_steps
114 self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
115 139
116 low_idx = np.floor(self.timesteps).astype(int) 140 timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
117 high_idx = np.ceil(self.timesteps).astype(int)
118 frac = np.mod(self.timesteps, 1.0)
119 sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 141 sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
120 sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] 142 sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
121 sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) 143 sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
122 self.sigmas = torch.from_numpy(sigmas) 144 self.sigmas = torch.from_numpy(sigmas).to(device=device)
123 self.timesteps = torch.from_numpy(self.timesteps) 145 self.timesteps = torch.from_numpy(timesteps).to(device=device)
124 self.init_noise_sigma = self.sigmas[0]
125 self.derivatives = []
126 146
127 def step( 147 def step(
128 self, 148 self,
129 model_output: Union[torch.FloatTensor, np.ndarray], 149 model_output: torch.FloatTensor,
130 timestep: Union[float, torch.FloatTensor], 150 timestep: Union[float, torch.FloatTensor],
131 step_index: Union[int, torch.IntTensor], 151 sample: torch.FloatTensor,
132 sample: Union[torch.FloatTensor, np.ndarray], 152 generator: Optional[torch.Generator] = None,
133 generator: torch.Generator = None,
134 return_dict: bool = True, 153 return_dict: bool = True,
135 ) -> Union[SchedulerOutput, Tuple]: 154 ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
136 """ 155 """
137 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 156 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
138 process from the learned model outputs (most often the predicted noise). 157 process from the learned model outputs (most often the predicted noise).
139 158
140 Args: 159 Args:
141 model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. 160 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
142 timestep (`int`): current discrete timestep in the diffusion chain. 161 timestep (`float`): current timestep in the diffusion chain.
143 sample (`torch.FloatTensor` or `np.ndarray`): 162 sample (`torch.FloatTensor`):
144 current instance of sample being created by diffusion process. 163 current instance of sample being created by diffusion process.
145 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class 164 generator (`torch.Generator`, optional): Random number generator.
165 return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
146 166
147 Returns: 167 Returns:
148 [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: 168 [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
149 [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When 169 [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
150 returning a tuple, the first element is the sample tensor. 170 a `tuple`. When returning a tuple, the first element is the sample tensor.
151 171
152 """ 172 """
173
174 if (
175 isinstance(timestep, int)
176 or isinstance(timestep, torch.IntTensor)
177 or isinstance(timestep, torch.LongTensor)
178 ):
179 raise ValueError(
180 "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
181 " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
182 " one of the `scheduler.timesteps` as a timestep.",
183 )
184
185 if not self.is_scale_input_called:
186 logger.warn(
187 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
188 "See `StableDiffusionPipeline` for a usage example."
189 )
190
191 if isinstance(timestep, torch.Tensor):
192 timestep = timestep.to(self.timesteps.device)
193
194 step_index = (self.timesteps == timestep).nonzero().item()
153 sigma = self.sigmas[step_index] 195 sigma = self.sigmas[step_index]
154 196
155 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 197 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
156 pred_original_sample = sample - sigma * model_output 198 pred_original_sample = sample - sigma * model_output
157 sigma_from = self.sigmas[step_index] 199 sigma_from = self.sigmas[step_index]
158 sigma_to = self.sigmas[step_index + 1] 200 sigma_to = self.sigmas[step_index + 1]
159 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 201 sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
160 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 202 sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
203
161 # 2. Convert to an ODE derivative 204 # 2. Convert to an ODE derivative
162 derivative = (sample - pred_original_sample) / sigma 205 derivative = (sample - pred_original_sample) / sigma
163 self.derivatives.append(derivative)
164 206
165 dt = sigma_down - sigma 207 dt = sigma_down - sigma
166 208
167 prev_sample = sample + derivative * dt 209 prev_sample = sample + derivative * dt
168 210
169 prev_sample = prev_sample + torch.randn( 211 device = model_output.device if torch.is_tensor(model_output) else "cpu"
170 prev_sample.shape, 212 noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
171 layout=prev_sample.layout, 213 prev_sample = prev_sample + noise * sigma_up
172 device=prev_sample.device,
173 dtype=prev_sample.dtype,
174 generator=generator
175 ) * sigma_up
176 214
177 if not return_dict: 215 if not return_dict:
178 return (prev_sample,) 216 return (prev_sample,)
179 217
180 return SchedulerOutput(prev_sample=prev_sample) 218 return EulerAncestralDiscreteSchedulerOutput(
219 prev_sample=prev_sample, pred_original_sample=pred_original_sample
220 )
181 221
182 def add_noise( 222 def add_noise(
183 self, 223 self,
184 original_samples: torch.FloatTensor, 224 original_samples: torch.FloatTensor,
185 noise: torch.FloatTensor, 225 noise: torch.FloatTensor,
186 timesteps: torch.IntTensor, 226 timesteps: torch.FloatTensor,
187 ) -> torch.FloatTensor: 227 ) -> torch.FloatTensor:
188 # Make sure sigmas and timesteps have the same device and dtype as original_samples 228 # Make sure sigmas and timesteps have the same device and dtype as original_samples
189 self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 229 self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
190 self.timesteps = self.timesteps.to(original_samples.device) 230 if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
191 sigma = self.sigmas[timesteps].flatten() 231 # mps does not support float64
232 self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
233 timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
234 else:
235 self.timesteps = self.timesteps.to(original_samples.device)
236 timesteps = timesteps.to(original_samples.device)
237
238 schedule_timesteps = self.timesteps
239
240 if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
241 deprecate(
242 "timesteps as indices",
243 "0.8.0",
244 "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
245 " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
246 " pass values from `scheduler.timesteps` as timesteps.",
247 standard_warn=False,
248 )
249 step_indices = timesteps
250 else:
251 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
252
253 sigma = self.sigmas[step_indices].flatten()
192 while len(sigma.shape) < len(original_samples.shape): 254 while len(sigma.shape) < len(original_samples.shape):
193 sigma = sigma.unsqueeze(-1) 255 sigma = sigma.unsqueeze(-1)
194 256