diff options
author | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
commit | 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch) | |
tree | ad186862f5095663966dd1d42455023080aa0c4e /schedulers | |
parent | Better sample file structure (diff) | |
download | textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2 textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip |
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 323 |
1 files changed, 323 insertions, 0 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py new file mode 100644 index 0000000..57a56de --- /dev/null +++ b/schedulers/scheduling_euler_a.py | |||
@@ -0,0 +1,323 @@ | |||
1 | |||
2 | |||
3 | import math | ||
4 | import warnings | ||
5 | from typing import Optional, Tuple, Union | ||
6 | |||
7 | import numpy as np | ||
8 | import torch | ||
9 | |||
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
11 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
12 | |||
13 | |||
14 | ''' | ||
15 | helper functions: append_zero(), | ||
16 | t_to_sigma(), | ||
17 | get_sigmas(), | ||
18 | append_dims(), | ||
19 | CFGDenoiserForward(), | ||
20 | get_scalings(), | ||
21 | DSsigma_to_t(), | ||
22 | DiscreteEpsDDPMDenoiserForward(), | ||
23 | to_d(), | ||
24 | get_ancestral_step() | ||
25 | need cleaning | ||
26 | ''' | ||
27 | |||
28 | |||
29 | def append_zero(x): | ||
30 | return torch.cat([x, x.new_zeros([1])]) | ||
31 | |||
32 | |||
33 | def t_to_sigma(t, sigmas): | ||
34 | t = t.float() | ||
35 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
36 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
37 | |||
38 | |||
39 | def get_sigmas(sigmas, n=None): | ||
40 | if n is None: | ||
41 | return append_zero(sigmas.flip(0)) | ||
42 | t_max = len(sigmas) - 1 # = 999 | ||
43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
45 | return append_zero(t_to_sigma(t, sigmas)) | ||
46 | |||
47 | # from k_samplers utils.py | ||
48 | |||
49 | |||
50 | def append_dims(x, target_dims): | ||
51 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
52 | dims_to_append = target_dims - x.ndim | ||
53 | if dims_to_append < 0: | ||
54 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
55 | return x[(...,) + (None,) * dims_to_append] | ||
56 | |||
57 | |||
58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | ||
59 | # x_in = torch.cat([x] * 2)#A# concat the latent | ||
60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | ||
61 | # cond_in = torch.cat([uncond, cond]) | ||
62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | ||
63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | ||
64 | # return uncond + (cond - uncond) * cond_scale | ||
65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | ||
66 | return noise_pred | ||
67 | |||
68 | # from k_samplers sampling.py | ||
69 | |||
70 | |||
71 | def to_d(x, sigma, denoised): | ||
72 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
73 | return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) | ||
74 | |||
75 | |||
76 | def get_scalings(sigma): | ||
77 | sigma_data = 1. | ||
78 | c_out = -sigma | ||
79 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
80 | return c_out, c_in | ||
81 | |||
82 | # DiscreteSchedule DS | ||
83 | |||
84 | |||
85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | ||
86 | # quantize = self.quantize if quantize is None else quantize | ||
87 | quantize = False | ||
88 | dists = torch.abs(sigma - DSsigmas[:, None]) | ||
89 | if quantize: | ||
90 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
91 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
92 | low, high = DSsigmas[low_idx], DSsigmas[high_idx] | ||
93 | w = (low - sigma) / (low - high) | ||
94 | w = w.clamp(0, 1) | ||
95 | t = (1 - w) * low_idx + w * high_idx | ||
96 | return t.view(sigma.shape) | ||
97 | |||
98 | |||
99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | ||
100 | sigma = sigma.to(Unet.device) | ||
101 | DSsigmas = DSsigmas.to(Unet.device) | ||
102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | ||
103 | # ??? what is eps? | ||
104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
106 | encoder_hidden_states=kwargs['cond']).sample | ||
107 | return input + eps * c_out | ||
108 | |||
109 | |||
110 | # from k_samplers sampling.py | ||
111 | def get_ancestral_step(sigma_from, sigma_to): | ||
112 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
113 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
114 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
115 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
116 | return sigma_down, sigma_up | ||
117 | |||
118 | |||
119 | ''' | ||
120 | Euler Ancestral Scheduler | ||
121 | ''' | ||
122 | |||
123 | |||
124 | class EulerAScheduler(SchedulerMixin, ConfigMixin): | ||
125 | """ | ||
126 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and | ||
127 | the VE column of Table 1 from [1] for reference. | ||
128 | |||
129 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." | ||
130 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic | ||
131 | differential equations." https://arxiv.org/abs/2011.13456 | ||
132 | |||
133 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
134 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
135 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
136 | [`~ConfigMixin.from_config`] functions. | ||
137 | |||
138 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of | ||
139 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the | ||
140 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. | ||
141 | |||
142 | Args: | ||
143 | sigma_min (`float`): minimum noise magnitude | ||
144 | sigma_max (`float`): maximum noise magnitude | ||
145 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. | ||
146 | A reasonable range is [1.000, 1.011]. | ||
147 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. | ||
148 | A reasonable range is [0, 100]. | ||
149 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). | ||
150 | A reasonable range is [0, 10]. | ||
151 | s_max (`float`): the end value of the sigma range where we add noise. | ||
152 | A reasonable range is [0.2, 80]. | ||
153 | |||
154 | """ | ||
155 | |||
156 | @register_to_config | ||
157 | def __init__( | ||
158 | self, | ||
159 | num_train_timesteps: int = 1000, | ||
160 | beta_start: float = 0.0001, | ||
161 | beta_end: float = 0.02, | ||
162 | beta_schedule: str = "linear", | ||
163 | trained_betas: Optional[np.ndarray] = None, | ||
164 | clip_sample: bool = True, | ||
165 | set_alpha_to_one: bool = True, | ||
166 | steps_offset: int = 0, | ||
167 | ): | ||
168 | if trained_betas is not None: | ||
169 | self.betas = torch.from_numpy(trained_betas) | ||
170 | if beta_schedule == "linear": | ||
171 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
172 | elif beta_schedule == "scaled_linear": | ||
173 | # this schedule is very specific to the latent diffusion model. | ||
174 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
175 | elif beta_schedule == "squaredcos_cap_v2": | ||
176 | # Glide cosine schedule | ||
177 | self.betas = betas_for_alpha_bar(num_train_timesteps) | ||
178 | else: | ||
179 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
180 | |||
181 | self.alphas = 1.0 - self.betas | ||
182 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
183 | |||
184 | # At every step in ddim, we are looking into the previous alphas_cumprod | ||
185 | # For the final step, there is no previous alphas_cumprod because we are already at 0 | ||
186 | # `set_alpha_to_one` decides whether we set this parameter simply to one or | ||
187 | # whether we use the final alpha of the "non-previous" one. | ||
188 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||
189 | |||
190 | # setable values | ||
191 | self.num_inference_steps = None | ||
192 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | ||
193 | |||
194 | # A# take number of steps as input | ||
195 | # A# store 1) number of steps 2) timesteps 3) schedule | ||
196 | |||
197 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): | ||
198 | """ | ||
199 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
200 | |||
201 | Args: | ||
202 | num_inference_steps (`int`): | ||
203 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
204 | """ | ||
205 | |||
206 | # offset = self.config.steps_offset | ||
207 | |||
208 | # if "offset" in kwargs: | ||
209 | # warnings.warn( | ||
210 | # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
211 | # " Please pass `steps_offset` to `__init__` instead.", | ||
212 | # DeprecationWarning, | ||
213 | # ) | ||
214 | |||
215 | # offset = kwargs["offset"] | ||
216 | |||
217 | self.num_inference_steps = num_inference_steps | ||
218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | ||
220 | self.timesteps = self.sigmas | ||
221 | |||
222 | def add_noise_to_input( | ||
223 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | ||
224 | ) -> Tuple[torch.FloatTensor, float]: | ||
225 | """ | ||
226 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a | ||
227 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. | ||
228 | |||
229 | TODO Args: | ||
230 | """ | ||
231 | if self.config.s_min <= sigma <= self.config.s_max: | ||
232 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) | ||
233 | else: | ||
234 | gamma = 0 | ||
235 | |||
236 | # sample eps ~ N(0, S_noise^2 * I) | ||
237 | eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) | ||
238 | sigma_hat = sigma + gamma * sigma | ||
239 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) | ||
240 | |||
241 | return sample_hat, sigma_hat | ||
242 | |||
243 | def step( | ||
244 | self, | ||
245 | model_output: torch.FloatTensor, | ||
246 | timestep: torch.IntTensor, | ||
247 | timestep_prev: torch.IntTensor, | ||
248 | sample: torch.FloatTensor, | ||
249 | generator: None, | ||
250 | # ,sigma_hat: float, | ||
251 | # sigma_prev: float, | ||
252 | # sample_hat: torch.FloatTensor, | ||
253 | return_dict: bool = True, | ||
254 | ) -> Union[SchedulerOutput, Tuple]: | ||
255 | """ | ||
256 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
257 | process from the learned model outputs (most often the predicted noise). | ||
258 | |||
259 | Args: | ||
260 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
261 | sigma_hat (`float`): TODO | ||
262 | sigma_prev (`float`): TODO | ||
263 | sample_hat (`torch.FloatTensor`): TODO | ||
264 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
265 | |||
266 | EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). | ||
267 | Returns: | ||
268 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: | ||
269 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
270 | returning a tuple, the first element is the sample tensor. | ||
271 | |||
272 | """ | ||
273 | latents = sample | ||
274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | ||
275 | |||
276 | # if callback is not None: | ||
277 | # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) | ||
278 | d = to_d(latents, timestep, model_output) | ||
279 | # Euler method | ||
280 | dt = sigma_down - timestep | ||
281 | latents = latents + d * dt | ||
282 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | ||
283 | generator=generator) * sigma_up | ||
284 | return SchedulerOutput(prev_sample=latents) | ||
285 | |||
286 | def step_correct( | ||
287 | self, | ||
288 | model_output: torch.FloatTensor, | ||
289 | sigma_hat: float, | ||
290 | sigma_prev: float, | ||
291 | sample_hat: torch.FloatTensor, | ||
292 | sample_prev: torch.FloatTensor, | ||
293 | derivative: torch.FloatTensor, | ||
294 | generator: None, | ||
295 | return_dict: bool = True, | ||
296 | ) -> Union[SchedulerOutput, Tuple]: | ||
297 | """ | ||
298 | Correct the predicted sample based on the output model_output of the network. TODO complete description | ||
299 | |||
300 | Args: | ||
301 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
302 | sigma_hat (`float`): TODO | ||
303 | sigma_prev (`float`): TODO | ||
304 | sample_hat (`torch.FloatTensor`): TODO | ||
305 | sample_prev (`torch.FloatTensor`): TODO | ||
306 | derivative (`torch.FloatTensor`): TODO | ||
307 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
308 | |||
309 | Returns: | ||
310 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO | ||
311 | |||
312 | """ | ||
313 | pred_original_sample = sample_prev + sigma_prev * model_output | ||
314 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev | ||
315 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) | ||
316 | |||
317 | if not return_dict: | ||
318 | return (sample_prev, derivative) | ||
319 | |||
320 | return SchedulerOutput(prev_sample=sample_prev) | ||
321 | |||
322 | def add_noise(self, original_samples, noise, timesteps): | ||
323 | raise NotImplementedError() | ||