summaryrefslogtreecommitdiffstats
path: root/schedulers/scheduling_euler_a.py
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers/scheduling_euler_a.py')
-rw-r--r--schedulers/scheduling_euler_a.py210
1 files changed, 92 insertions, 118 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 13ea6b3..6abe971 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
7from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput 7from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
8 8
9 9
10'''
11helper functions: append_zero(),
12 t_to_sigma(),
13 get_sigmas(),
14 append_dims(),
15 CFGDenoiserForward(),
16 get_scalings(),
17 DSsigma_to_t(),
18 DiscreteEpsDDPMDenoiserForward(),
19 to_d(),
20 get_ancestral_step()
21need cleaning
22'''
23
24
25def append_zero(x):
26 return torch.cat([x, x.new_zeros([1])])
27
28
29def t_to_sigma(t, sigmas):
30 t = t.float()
31 low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
32 return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
33
34
35def get_sigmas(sigmas, n=None):
36 if n is None:
37 return append_zero(sigmas.flip(0))
38 t_max = len(sigmas) - 1 # = 999
39 t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
40 return append_zero(t_to_sigma(t, sigmas))
41
42# from k_samplers utils.py
43
44
45def append_dims(x, target_dims):
46 """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
47 dims_to_append = target_dims - x.ndim
48 if dims_to_append < 0:
49 raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
50 return x[(...,) + (None,) * dims_to_append]
51
52
53def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
54 # x_in = torch.cat([x] * 2)#A# concat the latent
55 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
56 # cond_in = torch.cat([uncond, cond])
57 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
58 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
59 # return uncond + (cond - uncond) * cond_scale
60 noise_pred = DiscreteEpsDDPMDenoiserForward(
61 Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
62 return noise_pred
63
64# from k_samplers sampling.py
65
66
67def to_d(x, sigma, denoised):
68 """Converts a denoiser output to a Karras ODE derivative."""
69 return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim)
70
71
72def get_scalings(sigma):
73 sigma_data = 1.
74 c_out = -sigma
75 c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
76 return c_out, c_in
77
78# DiscreteSchedule DS
79
80
81def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
82 dists = torch.abs(sigma - DSsigmas[:, None])
83 if quantize:
84 return torch.argmin(dists, dim=0).view(sigma.shape)
85 low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
86 low, high = DSsigmas[low_idx], DSsigmas[high_idx]
87 w = (low - sigma) / (low - high)
88 w = w.clamp(0, 1)
89 t = (1 - w) * low_idx + w * high_idx
90 return t.view(sigma.shape)
91
92
93def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
94 sigma = sigma.to(dtype=input.dtype, device=Unet.device)
95 DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
96 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
97 # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
98 eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
99 encoder_hidden_states=kwargs['cond']).sample
100 return input + eps * c_out
101
102
103# from k_samplers sampling.py
104def get_ancestral_step(sigma_from, sigma_to):
105 """Calculates the noise level (sigma_down) to step down to and the amount
106 of noise to add (sigma_up) when doing an ancestral sampling step."""
107 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
108 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
109 return sigma_down, sigma_up
110
111
112'''
113Euler Ancestral Scheduler
114'''
115
116
117class EulerAScheduler(SchedulerMixin, ConfigMixin): 10class EulerAScheduler(SchedulerMixin, ConfigMixin):
118 """ 11 """
119 Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and 12 Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
@@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
154 beta_end: float = 0.02, 47 beta_end: float = 0.02,
155 beta_schedule: str = "linear", 48 beta_schedule: str = "linear",
156 trained_betas: Optional[np.ndarray] = None, 49 trained_betas: Optional[np.ndarray] = None,
50 tensor_format: str = "pt",
51 num_inference_steps=None,
52 device='cuda'
157 ): 53 ):
158 if trained_betas is not None: 54 if trained_betas is not None:
159 self.betas = torch.from_numpy(trained_betas) 55 self.betas = torch.from_numpy(trained_betas).to(device)
160 if beta_schedule == "linear": 56 if beta_schedule == "linear":
161 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 57 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
162 elif beta_schedule == "scaled_linear": 58 elif beta_schedule == "scaled_linear":
163 # this schedule is very specific to the latent diffusion model. 59 # this schedule is very specific to the latent diffusion model.
164 self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 60 self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
165 elif beta_schedule == "squaredcos_cap_v2": 61 dtype=torch.float32, device=device) ** 2
166 # Glide cosine schedule
167 self.betas = betas_for_alpha_bar(num_train_timesteps)
168 else: 62 else:
169 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 63 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
170 64
65 self.device = device
66 self.tensor_format = tensor_format
67
171 self.alphas = 1.0 - self.betas 68 self.alphas = 1.0 - self.betas
172 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 69 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
173 70
@@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
175 self.init_noise_sigma = 1.0 72 self.init_noise_sigma = 1.0
176 73
177 # setable values 74 # setable values
178 self.num_inference_steps = None 75 self.num_inference_steps = num_inference_steps
179 self.timesteps = np.arange(0, num_train_timesteps)[::-1] 76 self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
77 # get sigmas
78 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
79 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
80 self.set_format(tensor_format=tensor_format)
180 81
181 # A# take number of steps as input 82 # A# take number of steps as input
182 # A# store 1) number of steps 2) timesteps 3) schedule 83 # A# store 1) number of steps 2) timesteps 3) schedule
@@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
192 93
193 self.num_inference_steps = num_inference_steps 94 self.num_inference_steps = num_inference_steps
194 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 95 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
195 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 96 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
196 self.timesteps = self.sigmas[:-1] 97 self.timesteps = self.sigmas[:-1]
197 self.is_scale_input_called = False 98 self.is_scale_input_called = False
198 99
@@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
251 s_prev = self.sigmas[step_prev_index] 152 s_prev = self.sigmas[step_prev_index]
252 latents = sample 153 latents = sample
253 154
254 sigma_down, sigma_up = get_ancestral_step(s, s_prev) 155 sigma_down, sigma_up = self.get_ancestral_step(s, s_prev)
255 d = to_d(latents, s, model_output) 156 d = self.to_d(latents, s, model_output)
256 dt = sigma_down - s 157 dt = sigma_down - s
257 latents = latents + d * dt 158 latents = latents + d * dt
258 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, 159 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
@@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
313 noisy_samples = original_samples + noise * sigma 214 noisy_samples = original_samples + noise * sigma
314 self.is_scale_input_called = True 215 self.is_scale_input_called = True
315 return noisy_samples 216 return noisy_samples
217
218 # from k_samplers sampling.py
219
220 def get_ancestral_step(self, sigma_from, sigma_to):
221 """Calculates the noise level (sigma_down) to step down to and the amount
222 of noise to add (sigma_up) when doing an ancestral sampling step."""
223 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
224 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
225 return sigma_down, sigma_up
226
227 def t_to_sigma(self, t, sigmas):
228 t = t.float()
229 low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
230 return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
231
232 def append_zero(self, x):
233 return torch.cat([x, x.new_zeros([1])])
234
235 def get_sigmas(self, sigmas, n=None):
236 if n is None:
237 return self.append_zero(sigmas.flip(0))
238 t_max = len(sigmas) - 1 # = 999
239 device = self.device
240 t = torch.linspace(t_max, 0, n, device=device)
241 # t = torch.linspace(t_max, 0, n, device=sigmas.device)
242 return self.append_zero(self.t_to_sigma(t, sigmas))
243
244 # from k_samplers utils.py
245 def append_dims(self, x, target_dims):
246 """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
247 dims_to_append = target_dims - x.ndim
248 if dims_to_append < 0:
249 raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
250 return x[(...,) + (None,) * dims_to_append]
251
252 # from k_samplers sampling.py
253 def to_d(self, x, sigma, denoised):
254 """Converts a denoiser output to a Karras ODE derivative."""
255 return (x - denoised) / self.append_dims(sigma, x.ndim)
256
257 def get_scalings(self, sigma):
258 sigma_data = 1.
259 c_out = -sigma
260 c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
261 return c_out, c_in
262
263 # DiscreteSchedule DS
264 def DSsigma_to_t(self, sigma, quantize=None):
265 # quantize = self.quantize if quantize is None else quantize
266 quantize = False
267 dists = torch.abs(sigma - self.DSsigmas[:, None])
268 if quantize:
269 return torch.argmin(dists, dim=0).view(sigma.shape)
270 low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
271 low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
272 w = (low - sigma) / (low - high)
273 w = w.clamp(0, 1)
274 t = (1 - w) * low_idx + w * high_idx
275 return t.view(sigma.shape)
276
277 def prepare_input(self, latent_in, t, batch_size):
278 sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1
279
280 sigma_in = torch.cat([sigma] * 2 * batch_size)
281 # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
282 # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
283 c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
284
285 sigma_in = self.DSsigma_to_t(sigma_in)
286 # s_in = latent_in.new_ones([latent_in.shape[0]])
287 # sigma_in = sigma_in * s_in
288
289 return c_out, c_in, sigma_in