summaryrefslogtreecommitdiffstats
path: root/schedulers/scheduling_euler_ancestral_discrete.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-14 17:09:58 +0100
committerVolpeon <git@volpeon.ink>2022-11-14 17:09:58 +0100
commit2ad46871e2ead985445da2848a4eb7072b6e48aa (patch)
tree3137923e2c00fe1d3cd37ddcc93c8a847b0c0762 /schedulers/scheduling_euler_ancestral_discrete.py
parentUpdate (diff)
downloadtextual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.gz
textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.bz2
textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.zip
Update
Diffstat (limited to 'schedulers/scheduling_euler_ancestral_discrete.py')
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py261
1 files changed, 0 insertions, 261 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
deleted file mode 100644
index cef50fe..0000000
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ /dev/null
@@ -1,261 +0,0 @@
1# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from dataclasses import dataclass
16from typing import Optional, Tuple, Union
17
18import numpy as np
19import torch
20
21from diffusers.configuration_utils import ConfigMixin, register_to_config
22from diffusers.utils import BaseOutput, deprecate, logging
23from diffusers.schedulers.scheduling_utils import SchedulerMixin
24
25
# Module-level logger, created through the `logging` helper imported from `diffusers.utils` and keyed by
# this module's name. It is used by the scheduler below to emit usage warnings.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
27
28
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Container for the values produced by one call to the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample for the previous timestep (x_{t-1}). Feed this back to the model as the next input of the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The denoised sample (x_{0}) predicted from the model output at the current timestep. Optional; useful
            for previewing progress or for guidance. Defaults to `None`.
    """

    # Required: the next input for the denoising loop.
    prev_sample: torch.FloatTensor
    # Optional: the current estimate of the fully denoised sample.
    pred_original_sample: Optional[torch.FloatTensor] = None
46
47
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            Other array-likes accepted by `torch.as_tensor` (e.g. a plain list of floats) work as well.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        if trained_betas is not None:
            # `as_tensor` behaves like `from_numpy` for ndarrays but, unlike it, does not raise a TypeError
            # for lists or tensors.
            self.betas = torch.as_tensor(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # One sigma per training timestep, reversed to match the descending `timesteps` below, followed by a
        # trailing 0.0 so that `sigmas[step_index + 1]` is defined for the last step.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # Flipped by `scale_model_input`; `step` warns when it is still False.
        self.is_scale_input_called = False

    def _index_for_timestep(self, timestep: Union[float, torch.FloatTensor]) -> int:
        """
        Return the position of `timestep` inside `self.timesteps`.

        A tensor `timestep` is first moved to the device of `self.timesteps`. The lookup compares with exact
        equality and calls `.item()` on the matching indices, so `timestep` has to match exactly one entry of the
        schedule.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        return (self.timesteps == timestep).nonzero().item()

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Divides the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm, where `sigma` is
        the entry of `self.sigmas` that belongs to `timestep`.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        sigma = self.sigmas[self._index_for_timestep(timestep)]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # Remember that scaling happened so `step` does not warn.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # Evenly spaced (generally fractional) timesteps in descending order.
        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        # Linearly interpolate the per-training-step sigmas at those timesteps, then append the final 0.0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, optional): Random number generator.
            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

        Returns:
            [`EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
            tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `timestep` is an `int`, a `torch.IntTensor` or a `torch.LongTensor`, i.e. looks like an
                index rather than a value taken from `scheduler.timesteps`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep.",
            )

        if not self.is_scale_input_called:
            # `Logger.warn` is a deprecated alias (removed in Python 3.13); use `warning`.
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        step_index = self._index_for_timestep(timestep)
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # Split the move from the current sigma to the next one into a deterministic part (`sigma_down`) and a
        # fresh-noise part (`sigma_up`); by construction sigma_down**2 + sigma_up**2 == sigma_to**2.
        sigma_from = sigma
        sigma_to = self.sigmas[step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        # Deterministic Euler update towards `sigma_down`.
        prev_sample = sample + derivative * dt

        # Add back noise with standard deviation `sigma_up`.
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
        prev_sample = prev_sample + noise * sigma_up

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Return `original_samples + noise * sigma`, with one `sigma` looked up per entry of `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): the samples to add noise to.
            noise (`torch.FloatTensor`): the noise to add; it is multiplied by the looked-up sigma.
            timesteps (`torch.FloatTensor`):
                values taken from `scheduler.timesteps`, one per sample. Integer (`torch.IntTensor` /
                `torch.LongTensor`) tensors are still interpreted as indices into the schedule, but trigger a
                deprecation notice.

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        # NOTE: this reassigns `self.sigmas` and `self.timesteps`, so the scheduler's stored tensors are changed
        # by this call.
        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            self.timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        schedule_timesteps = self.timesteps

        if isinstance(timesteps, (torch.IntTensor, torch.LongTensor)):
            deprecate(
                "timesteps as indices",
                "0.8.0",
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
                " pass values from `scheduler.timesteps` as timesteps.",
                standard_warn=False,
            )
            step_indices = timesteps
        else:
            # Exact-match lookup of every timestep value in the current schedule.
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        # Append trailing singleton dimensions until sigma broadcasts against the samples.
        sigma = self.sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps