diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 29ebd07..9fbedaa 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None): | |||
| 41 | return append_zero(sigmas.flip(0)) | 41 | return append_zero(sigmas.flip(0)) |
| 42 | t_max = len(sigmas) - 1 # = 999 | 42 | t_max = len(sigmas) - 1 # = 999 |
| 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) |
| 44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 45 | return append_zero(t_to_sigma(t, sigmas)) | 44 | return append_zero(t_to_sigma(t, sigmas)) |
| 46 | 45 | ||
| 47 | # from k_samplers utils.py | 46 | # from k_samplers utils.py |
| @@ -55,14 +54,15 @@ def append_dims(x, target_dims): | |||
| 55 | return x[(...,) + (None,) * dims_to_append] | 54 | return x[(...,) + (None,) * dims_to_append] |
| 56 | 55 | ||
| 57 | 56 | ||
| 58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | 57 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): |
| 59 | # x_in = torch.cat([x] * 2)#A# concat the latent | 58 | # x_in = torch.cat([x] * 2)#A# concat the latent |
| 60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | 59 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma |
| 61 | # cond_in = torch.cat([uncond, cond]) | 60 | # cond_in = torch.cat([uncond, cond]) |
| 62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | 61 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) |
| 63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | 62 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) |
| 64 | # return uncond + (cond - uncond) * cond_scale | 63 | # return uncond + (cond - uncond) * cond_scale |
| 65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | 64 | noise_pred = DiscreteEpsDDPMDenoiserForward( |
| 65 | Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) | ||
| 66 | return noise_pred | 66 | return noise_pred |
| 67 | 67 | ||
| 68 | # from k_samplers sampling.py | 68 | # from k_samplers sampling.py |
| @@ -82,9 +82,7 @@ def get_scalings(sigma): | |||
| 82 | # DiscreteSchedule DS | 82 | # DiscreteSchedule DS |
| 83 | 83 | ||
| 84 | 84 | ||
| 85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | 85 | def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): |
| 86 | # quantize = self.quantize if quantize is None else quantize | ||
| 87 | quantize = False | ||
| 88 | dists = torch.abs(sigma - DSsigmas[:, None]) | 86 | dists = torch.abs(sigma - DSsigmas[:, None]) |
| 89 | if quantize: | 87 | if quantize: |
| 90 | return torch.argmin(dists, dim=0).view(sigma.shape) | 88 | return torch.argmin(dists, dim=0).view(sigma.shape) |
| @@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | |||
| 96 | return t.view(sigma.shape) | 94 | return t.view(sigma.shape) |
| 97 | 95 | ||
| 98 | 96 | ||
| 99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | 97 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): |
| 100 | sigma = sigma.to(Unet.device) | 98 | sigma = sigma.to(Unet.device) |
| 101 | DSsigmas = DSsigmas.to(Unet.device) | 99 | DSsigmas = DSsigmas.to(Unet.device) |
| 102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | 100 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] |
| 103 | # ??? what is eps? | 101 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), |
| 104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
| 105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
| 106 | encoder_hidden_states=kwargs['cond']).sample | 102 | encoder_hidden_states=kwargs['cond']).sample |
| 107 | return input + eps * c_out | 103 | return input + eps * c_out |
| 108 | 104 | ||
