diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 29ebd07..9fbedaa 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None): | |||
41 | return append_zero(sigmas.flip(0)) | 41 | return append_zero(sigmas.flip(0)) |
42 | t_max = len(sigmas) - 1 # = 999 | 42 | t_max = len(sigmas) - 1 # = 999 |
43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) |
44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
45 | return append_zero(t_to_sigma(t, sigmas)) | 44 | return append_zero(t_to_sigma(t, sigmas)) |
46 | 45 | ||
47 | # from k_samplers utils.py | 46 | # from k_samplers utils.py |
@@ -55,14 +54,15 @@ def append_dims(x, target_dims): | |||
55 | return x[(...,) + (None,) * dims_to_append] | 54 | return x[(...,) + (None,) * dims_to_append] |
56 | 55 | ||
57 | 56 | ||
58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | 57 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): |
59 | # x_in = torch.cat([x] * 2)#A# concat the latent | 58 | # x_in = torch.cat([x] * 2)#A# concat the latent |
60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | 59 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma |
61 | # cond_in = torch.cat([uncond, cond]) | 60 | # cond_in = torch.cat([uncond, cond]) |
62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | 61 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) |
63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | 62 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) |
64 | # return uncond + (cond - uncond) * cond_scale | 63 | # return uncond + (cond - uncond) * cond_scale |
65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | 64 | noise_pred = DiscreteEpsDDPMDenoiserForward( |
65 | Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) | ||
66 | return noise_pred | 66 | return noise_pred |
67 | 67 | ||
68 | # from k_samplers sampling.py | 68 | # from k_samplers sampling.py |
@@ -82,9 +82,7 @@ def get_scalings(sigma): | |||
82 | # DiscreteSchedule DS | 82 | # DiscreteSchedule DS |
83 | 83 | ||
84 | 84 | ||
85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | 85 | def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): |
86 | # quantize = self.quantize if quantize is None else quantize | ||
87 | quantize = False | ||
88 | dists = torch.abs(sigma - DSsigmas[:, None]) | 86 | dists = torch.abs(sigma - DSsigmas[:, None]) |
89 | if quantize: | 87 | if quantize: |
90 | return torch.argmin(dists, dim=0).view(sigma.shape) | 88 | return torch.argmin(dists, dim=0).view(sigma.shape) |
@@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | |||
96 | return t.view(sigma.shape) | 94 | return t.view(sigma.shape) |
97 | 95 | ||
98 | 96 | ||
99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | 97 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): |
100 | sigma = sigma.to(Unet.device) | 98 | sigma = sigma.to(Unet.device) |
101 | DSsigmas = DSsigmas.to(Unet.device) | 99 | DSsigmas = DSsigmas.to(Unet.device) |
102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | 100 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] |
103 | # ??? what is eps? | 101 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), |
104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
106 | encoder_hidden_states=kwargs['cond']).sample | 102 | encoder_hidden_states=kwargs['cond']).sample |
107 | return input + eps * c_out | 103 | return input + eps * c_out |
108 | 104 | ||