From f5b656d21c5b449eed6ce212e909043c124f79ee Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 12 Oct 2022 08:18:22 +0200 Subject: Various updates --- schedulers/scheduling_euler_a.py | 210 +++++++++++++++++---------------------- 1 file changed, 92 insertions(+), 118 deletions(-) (limited to 'schedulers/scheduling_euler_a.py') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 13ea6b3..6abe971 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -''' -helper functions: append_zero(), - t_to_sigma(), - get_sigmas(), - append_dims(), - CFGDenoiserForward(), - get_scalings(), - DSsigma_to_t(), - DiscreteEpsDDPMDenoiserForward(), - to_d(), - get_ancestral_step() -need cleaning -''' - - -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) - - -def t_to_sigma(t, sigmas): - t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() - return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] - - -def get_sigmas(sigmas, n=None): - if n is None: - return append_zero(sigmas.flip(0)) - t_max = len(sigmas) - 1 # = 999 - t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) - return append_zero(t_to_sigma(t, sigmas)) - -# from k_samplers utils.py - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - -def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): - # x_in = torch.cat([x] * 2)#A# concat the latent - # sigma_in = torch.cat([sigma] * 2) #A# concat sigma - # cond_in = torch.cat([uncond, cond]) - # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) - # return uncond + (cond - uncond) * cond_scale - noise_pred = DiscreteEpsDDPMDenoiserForward( - Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) - return noise_pred - -# from k_samplers sampling.py - - -def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) - - -def get_scalings(sigma): - sigma_data = 1. - c_out = -sigma - c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 - return c_out, c_in - -# DiscreteSchedule DS - - -def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): - dists = torch.abs(sigma - DSsigmas[:, None]) - if quantize: - return torch.argmin(dists, dim=0).view(sigma.shape) - low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] - low, high = DSsigmas[low_idx], DSsigmas[high_idx] - w = (low - sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - -def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): - sigma = sigma.to(dtype=input.dtype, device=Unet.device) - DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) - c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] - # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") - eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), - encoder_hidden_states=kwargs['cond']).sample - return input + eps * c_out - - -# from k_samplers sampling.py -def get_ancestral_step(sigma_from, sigma_to): - """Calculates the noise level (sigma_down) to step down to and the amount - of noise to add (sigma_up) when doing an ancestral sampling step.""" - sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 - return sigma_down, sigma_up - - -''' -Euler Ancestral Scheduler -''' - - class EulerAScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and @@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + num_inference_steps=None, + device='cuda' ): if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas) + self.betas = torch.from_numpy(trained_betas).to(device) if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, + dtype=torch.float32, device=device) ** 2 else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.device = device + self.tensor_format = tensor_format + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + # get sigmas + self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) + self.set_format(tensor_format=tensor_format) # A# take number of steps as input # A# store 1) number of steps 2) timesteps 3) schedule @@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) + self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) self.timesteps = self.sigmas[:-1] self.is_scale_input_called = False @@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): s_prev = self.sigmas[step_prev_index] latents = sample - sigma_down, sigma_up = get_ancestral_step(s, s_prev) - d = to_d(latents, s, model_output) + sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) + d = self.to_d(latents, s, model_output) dt = sigma_down - s latents = latents + d * dt latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, @@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): noisy_samples = original_samples + noise * sigma self.is_scale_input_called = True return noisy_samples + + # from k_samplers sampling.py + + def get_ancestral_step(self, sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + def t_to_sigma(self, t, sigmas): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] + + def append_zero(self, x): + return torch.cat([x, x.new_zeros([1])]) + + def get_sigmas(self, sigmas, n=None): + if n is None: + return self.append_zero(sigmas.flip(0)) + t_max = len(sigmas) - 1 # = 999 + device = self.device + t = torch.linspace(t_max, 0, n, device=device) + # t = torch.linspace(t_max, 0, n, device=sigmas.device) + return self.append_zero(self.t_to_sigma(t, sigmas)) + + # from k_samplers utils.py + def append_dims(self, x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + # from k_samplers sampling.py + def to_d(self, x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / self.append_dims(sigma, x.ndim) + + def get_scalings(self, sigma): + sigma_data = 1. + c_out = -sigma + c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 + return c_out, c_in + + # DiscreteSchedule DS + def DSsigma_to_t(self, sigma, quantize=None): + # quantize = self.quantize if quantize is None else quantize + quantize = False + dists = torch.abs(sigma - self.DSsigmas[:, None]) + if quantize: + return torch.argmin(dists, dim=0).view(sigma.shape) + low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] + w = (low - sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def prepare_input(self, latent_in, t, batch_size): + sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 + + sigma_in = torch.cat([sigma] * 2 * batch_size) + # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) + # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) + c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] + + sigma_in = self.DSsigma_to_t(sigma_in) + # s_in = latent_in.new_ones([latent_in.shape[0]]) + # sigma_in = sigma_in * s_in + + return c_out, c_in, sigma_in -- cgit v1.2.3-54-g00ecf