summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py16
1 files changed, 6 insertions, 10 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 29ebd07..9fbedaa 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None):
41 return append_zero(sigmas.flip(0)) 41 return append_zero(sigmas.flip(0))
42 t_max = len(sigmas) - 1 # = 999 42 t_max = len(sigmas) - 1 # = 999
43 t = torch.linspace(t_max, 0, n, device=sigmas.device) 43 t = torch.linspace(t_max, 0, n, device=sigmas.device)
44 # t = torch.linspace(t_max, 0, n, device=sigmas.device)
45 return append_zero(t_to_sigma(t, sigmas)) 44 return append_zero(t_to_sigma(t, sigmas))
46 45
47# from k_samplers utils.py 46# from k_samplers utils.py
@@ -55,14 +54,15 @@ def append_dims(x, target_dims):
55 return x[(...,) + (None,) * dims_to_append] 54 return x[(...,) + (None,) * dims_to_append]
56 55
57 56
58def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): 57def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
59 # x_in = torch.cat([x] * 2)#A# concat the latent 58 # x_in = torch.cat([x] * 2)#A# concat the latent
60 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma 59 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
61 # cond_in = torch.cat([uncond, cond]) 60 # cond_in = torch.cat([uncond, cond])
62 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) 61 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
63 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) 62 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
64 # return uncond + (cond - uncond) * cond_scale 63 # return uncond + (cond - uncond) * cond_scale
65 noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) 64 noise_pred = DiscreteEpsDDPMDenoiserForward(
65 Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
66 return noise_pred 66 return noise_pred
67 67
68# from k_samplers sampling.py 68# from k_samplers sampling.py
@@ -82,9 +82,7 @@ def get_scalings(sigma):
82# DiscreteSchedule DS 82# DiscreteSchedule DS
83 83
84 84
85def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): 85def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
86 # quantize = self.quantize if quantize is None else quantize
87 quantize = False
88 dists = torch.abs(sigma - DSsigmas[:, None]) 86 dists = torch.abs(sigma - DSsigmas[:, None])
89 if quantize: 87 if quantize:
90 return torch.argmin(dists, dim=0).view(sigma.shape) 88 return torch.argmin(dists, dim=0).view(sigma.shape)
@@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
96 return t.view(sigma.shape) 94 return t.view(sigma.shape)
97 95
98 96
99def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): 97def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
100 sigma = sigma.to(Unet.device) 98 sigma = sigma.to(Unet.device)
101 DSsigmas = DSsigmas.to(Unet.device) 99 DSsigmas = DSsigmas.to(Unet.device)
102 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] 100 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
103 # ??? what is eps? 101 eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
104 # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs)
105 eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas),
106 encoder_hidden_states=kwargs['cond']).sample 102 encoder_hidden_states=kwargs['cond']).sample
107 return input + eps * c_out 103 return input + eps * c_out
108 104