From 6720c99f7082dc855059ad4afd6b3cb45b62bc1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Oct 2022 16:53:19 +0200 Subject: Fix seed, better progress bar, fix euler_a for batch size > 1 --- schedulers/scheduling_euler_a.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 29ebd07..9fbedaa 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None): return append_zero(sigmas.flip(0)) t_max = len(sigmas) - 1 # = 999 t = torch.linspace(t_max, 0, n, device=sigmas.device) - # t = torch.linspace(t_max, 0, n, device=sigmas.device) return append_zero(t_to_sigma(t, sigmas)) # from k_samplers utils.py @@ -55,14 +54,15 @@ def append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): +def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): # x_in = torch.cat([x] * 2)#A# concat the latent # sigma_in = torch.cat([sigma] * 2) #A# concat sigma # cond_in = torch.cat([uncond, cond]) # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) # return uncond + (cond - uncond) * cond_scale - noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) + noise_pred = DiscreteEpsDDPMDenoiserForward( + Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) return noise_pred # from k_samplers sampling.py @@ -82,9 +82,7 @@ def get_scalings(sigma): # DiscreteSchedule DS -def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): - # quantize = self.quantize if quantize is None else quantize - quantize = False +def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): dists = torch.abs(sigma - DSsigmas[:, None]) if quantize: return torch.argmin(dists, dim=0).view(sigma.shape) @@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): return t.view(sigma.shape) -def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): +def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): sigma = sigma.to(Unet.device) DSsigmas = DSsigmas.to(Unet.device) c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] - # ??? what is eps? - # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) - eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), + eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), encoder_hidden_states=kwargs['cond']).sample return input + eps * c_out -- cgit v1.2.3-70-g09d2