author    Volpeon <git@volpeon.ink>  2023-02-16 09:16:05 +0100
committer Volpeon <git@volpeon.ink>  2023-02-16 09:16:05 +0100
commit    d673760fc671d665aadae3b032f8e99f21ab986d (patch)
tree      7c14a998742b19ddecac6ee25a669892b41c305e /schedulers
parent    Update (diff)
Integrated WIP UniPC scheduler
Diffstat (limited to 'schedulers')
-rw-r--r--  schedulers/scheduling_unipc_multistep.py  |  615
1 file changed, 615 insertions, 0 deletions
diff --git a/schedulers/scheduling_unipc_multistep.py b/schedulers/scheduling_unipc_multistep.py
new file mode 100644
index 0000000..ff5db24
--- /dev/null
+++ b/schedulers/scheduling_unipc_multistep.py
@@ -0,0 +1,615 @@
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

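# Illustrative sketch of what this cosine schedule yields; 1000 steps is an
# arbitrary example value, not a requirement:
#
#     betas = betas_for_alpha_bar(1000)
#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
#     # alphas_cumprod decays smoothly from ~1 toward 0, tracking
#     # alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
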
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
    corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
    UniPC is model-agnostic by design, supporting pixel-space/latent-space DPMs on unconditional/conditional
    sampling. It can be applied to both noise prediction and data prediction models. The corrector UniC can
    also be applied after any off-the-shelf solver to increase the order of accuracy.

    For more details, see the original paper: https://arxiv.org/abs/2302.04867

    Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We
    recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.

    We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
    diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use dynamic
    thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
    stable-diffusion).

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end`, etc.
        solver_order (`int`, default `2`):
            the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
            accuracy is `solver_order + 1` due to the UniC. We recommend using `solver_order=2` for guided sampling,
            and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
            use dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
            models (such as stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487).
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, default `True`):
            whether to use the updating algorithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for
            details.
        solver_type (`str`, default `bh1`):
            the solver type of UniPC. We recommend using `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            decides at which steps to disable the corrector. For a large guidance scale, the misalignment between
            `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)` might influence convergence. This can be mitigated
            by disabling the corrector for the first few steps (e.g., `disable_corrector=[0]`).
        solver_p (`SchedulerMixin`):
            can be any other scheduler. If specified, the algorithm becomes solver_p + UniC.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh1",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None  # guard: consumed by the corrector in step()

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = (
            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
        if self.solver_p:
            self.solver_p.set_timesteps(num_inference_steps, device=device)

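    # Illustrative result, assuming the default num_train_timesteps=1000:
    #
    #     scheduler.set_timesteps(10)
    #     # scheduler.timesteps is 10 descending integer timesteps with the
    #     # final 0 dropped, e.g.
    #     # tensor([999, 899, 799, 699, 599, 500, 400, 300, 200, 100])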
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ):
        r"""
        Convert the model output to the corresponding type that the UniPC algorithm needs.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the converted model output.
        """
        if self.predict_x0:
            if self.config.prediction_type == "epsilon":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
                orig_dtype = x0_pred.dtype
                if orig_dtype not in [torch.float, torch.double]:
                    x0_pred = x0_pred.float()
                dynamic_max_val = torch.quantile(
                    torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
                )
                dynamic_max_val = torch.maximum(
                    dynamic_max_val,
                    self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
                )[(...,) + (None,) * (x0_pred.ndim - 1)]
                x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
                x0_pred = x0_pred.type(orig_dtype)
            return x0_pred
        else:
            if self.config.prediction_type == "epsilon":
                return model_output
            elif self.config.prediction_type == "sample":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

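    # The conversions above follow the VP forward process
    # x_t = alpha_t * x0 + sigma_t * eps, so for example:
    #
    #     x0  = (x_t - sigma_t * eps) / alpha_t    # epsilon -> x0
    #     eps = (x_t - alpha_t * x0) / sigma_t     # sample  -> epsilon
    #     # and for v-prediction: x0 = alpha_t * x_t - sigma_t * v,
    #     #                       eps = alpha_t * v + sigma_t * x_t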
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ):
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.FloatTensor`):
                direct outputs from learned diffusion model at the current timestep.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order (`int`): the order of UniP at this step, also the p in UniPC-p.

        Returns:
            `torch.FloatTensor`: the sample tensor at the previous timestep.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = self.timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, C, H, W)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

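    # Math note (a sketch of the bookkeeping above): UniP and UniC assemble a
    # small Vandermonde-type system from the log-SNR ratios
    # r_k = (lambda_{s_k} - lambda_{s_0}) / h,
    #
    #     R[i-1, k] = r_k ** (i - 1),    b[i-1] = h * phi_{i+1}(h) * i! / B(h),
    #
    # for i = 1..order, and solve R @ rhos = b for the weights applied to the
    # differences D1s. The predictor drops the last row/column because the
    # model output at the target timestep is not yet available.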
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ):
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.FloatTensor`): the model output at `x_t`
            this_timestep (`int`): the current timestep `t`
            last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
            this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
            order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
                should be order + 1

        Returns:
            `torch.FloatTensor`: the corrected sample tensor at the current timestep.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

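    # Unlike UniP, UniC has the model evaluation at the target timestep
    # available, so it solves the full system R @ rhos_c = b and applies the
    # extra weight rhos_c[-1] to D1_t = model_t - m0, raising the effective
    # order of accuracy to order + 1.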
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the multistep UniPC.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """

        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero()
        if len(step_index) == 0:
            step_index = len(self.timesteps) - 1
        else:
            step_index = step_index.item()

        use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector

        model_output_convert = self.convert_model_output(model_output, timestep, sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                this_timestep=timestep,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # now prepare to run the predictor
        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]

        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config.lower_order_final:
            this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            prev_timestep=prev_timestep,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

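    # Flow of a single step() call: UniC first refines the sample produced by
    # the previous call's predictor using the fresh model output, then UniP
    # produces the next sample. The corrector is skipped on the first step and
    # on any step listed in `disable_corrector`.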
    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps