diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-12 08:18:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-12 08:18:22 +0200 |
| commit | f5b656d21c5b449eed6ce212e909043c124f79ee (patch) | |
| tree | 905f20900433f1e77840cd66417395168e0eec7f /schedulers | |
| parent | Added EMA support to Textual Inversion (diff) | |
| download | textual-inversion-diff-f5b656d21c5b449eed6ce212e909043c124f79ee.tar.gz textual-inversion-diff-f5b656d21c5b449eed6ce212e909043c124f79ee.tar.bz2 textual-inversion-diff-f5b656d21c5b449eed6ce212e909043c124f79ee.zip | |
Various updates
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 210 |
1 file changed, 92 insertions, 118 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 13ea6b3..6abe971 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config | |||
| 7 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | 7 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput |
| 8 | 8 | ||
| 9 | 9 | ||
| 10 | ''' | ||
| 11 | helper functions: append_zero(), | ||
| 12 | t_to_sigma(), | ||
| 13 | get_sigmas(), | ||
| 14 | append_dims(), | ||
| 15 | CFGDenoiserForward(), | ||
| 16 | get_scalings(), | ||
| 17 | DSsigma_to_t(), | ||
| 18 | DiscreteEpsDDPMDenoiserForward(), | ||
| 19 | to_d(), | ||
| 20 | get_ancestral_step() | ||
| 21 | need cleaning | ||
| 22 | ''' | ||
| 23 | |||
| 24 | |||
| 25 | def append_zero(x): | ||
| 26 | return torch.cat([x, x.new_zeros([1])]) | ||
| 27 | |||
| 28 | |||
| 29 | def t_to_sigma(t, sigmas): | ||
| 30 | t = t.float() | ||
| 31 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
| 32 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
| 33 | |||
| 34 | |||
| 35 | def get_sigmas(sigmas, n=None): | ||
| 36 | if n is None: | ||
| 37 | return append_zero(sigmas.flip(0)) | ||
| 38 | t_max = len(sigmas) - 1 # = 999 | ||
| 39 | t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) | ||
| 40 | return append_zero(t_to_sigma(t, sigmas)) | ||
| 41 | |||
| 42 | # from k_samplers utils.py | ||
| 43 | |||
| 44 | |||
| 45 | def append_dims(x, target_dims): | ||
| 46 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
| 47 | dims_to_append = target_dims - x.ndim | ||
| 48 | if dims_to_append < 0: | ||
| 49 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
| 50 | return x[(...,) + (None,) * dims_to_append] | ||
| 51 | |||
| 52 | |||
| 53 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): | ||
| 54 | # x_in = torch.cat([x] * 2)#A# concat the latent | ||
| 55 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | ||
| 56 | # cond_in = torch.cat([uncond, cond]) | ||
| 57 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | ||
| 58 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | ||
| 59 | # return uncond + (cond - uncond) * cond_scale | ||
| 60 | noise_pred = DiscreteEpsDDPMDenoiserForward( | ||
| 61 | Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) | ||
| 62 | return noise_pred | ||
| 63 | |||
| 64 | # from k_samplers sampling.py | ||
| 65 | |||
| 66 | |||
| 67 | def to_d(x, sigma, denoised): | ||
| 68 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
| 69 | return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) | ||
| 70 | |||
| 71 | |||
| 72 | def get_scalings(sigma): | ||
| 73 | sigma_data = 1. | ||
| 74 | c_out = -sigma | ||
| 75 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
| 76 | return c_out, c_in | ||
| 77 | |||
| 78 | # DiscreteSchedule DS | ||
| 79 | |||
| 80 | |||
| 81 | def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): | ||
| 82 | dists = torch.abs(sigma - DSsigmas[:, None]) | ||
| 83 | if quantize: | ||
| 84 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
| 85 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
| 86 | low, high = DSsigmas[low_idx], DSsigmas[high_idx] | ||
| 87 | w = (low - sigma) / (low - high) | ||
| 88 | w = w.clamp(0, 1) | ||
| 89 | t = (1 - w) * low_idx + w * high_idx | ||
| 90 | return t.view(sigma.shape) | ||
| 91 | |||
| 92 | |||
| 93 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): | ||
| 94 | sigma = sigma.to(dtype=input.dtype, device=Unet.device) | ||
| 95 | DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) | ||
| 96 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | ||
| 97 | # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") | ||
| 98 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), | ||
| 99 | encoder_hidden_states=kwargs['cond']).sample | ||
| 100 | return input + eps * c_out | ||
| 101 | |||
| 102 | |||
| 103 | # from k_samplers sampling.py | ||
| 104 | def get_ancestral_step(sigma_from, sigma_to): | ||
| 105 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
| 106 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
| 107 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
| 108 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
| 109 | return sigma_down, sigma_up | ||
| 110 | |||
| 111 | |||
| 112 | ''' | ||
| 113 | Euler Ancestral Scheduler | ||
| 114 | ''' | ||
| 115 | |||
| 116 | |||
| 117 | class EulerAScheduler(SchedulerMixin, ConfigMixin): | 10 | class EulerAScheduler(SchedulerMixin, ConfigMixin): |
| 118 | """ | 11 | """ |
| 119 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and | 12 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and |
| @@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 154 | beta_end: float = 0.02, | 47 | beta_end: float = 0.02, |
| 155 | beta_schedule: str = "linear", | 48 | beta_schedule: str = "linear", |
| 156 | trained_betas: Optional[np.ndarray] = None, | 49 | trained_betas: Optional[np.ndarray] = None, |
| 50 | tensor_format: str = "pt", | ||
| 51 | num_inference_steps=None, | ||
| 52 | device='cuda' | ||
| 157 | ): | 53 | ): |
| 158 | if trained_betas is not None: | 54 | if trained_betas is not None: |
| 159 | self.betas = torch.from_numpy(trained_betas) | 55 | self.betas = torch.from_numpy(trained_betas).to(device) |
| 160 | if beta_schedule == "linear": | 56 | if beta_schedule == "linear": |
| 161 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | 57 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) |
| 162 | elif beta_schedule == "scaled_linear": | 58 | elif beta_schedule == "scaled_linear": |
| 163 | # this schedule is very specific to the latent diffusion model. | 59 | # this schedule is very specific to the latent diffusion model. |
| 164 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | 60 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, |
| 165 | elif beta_schedule == "squaredcos_cap_v2": | 61 | dtype=torch.float32, device=device) ** 2 |
| 166 | # Glide cosine schedule | ||
| 167 | self.betas = betas_for_alpha_bar(num_train_timesteps) | ||
| 168 | else: | 62 | else: |
| 169 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | 63 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") |
| 170 | 64 | ||
| 65 | self.device = device | ||
| 66 | self.tensor_format = tensor_format | ||
| 67 | |||
| 171 | self.alphas = 1.0 - self.betas | 68 | self.alphas = 1.0 - self.betas |
| 172 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 69 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| 173 | 70 | ||
| @@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 175 | self.init_noise_sigma = 1.0 | 72 | self.init_noise_sigma = 1.0 |
| 176 | 73 | ||
| 177 | # setable values | 74 | # setable values |
| 178 | self.num_inference_steps = None | 75 | self.num_inference_steps = num_inference_steps |
| 179 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | 76 | self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() |
| 77 | # get sigmas | ||
| 78 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
| 79 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) | ||
| 80 | self.set_format(tensor_format=tensor_format) | ||
| 180 | 81 | ||
| 181 | # A# take number of steps as input | 82 | # A# take number of steps as input |
| 182 | # A# store 1) number of steps 2) timesteps 3) schedule | 83 | # A# store 1) number of steps 2) timesteps 3) schedule |
| @@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 192 | 93 | ||
| 193 | self.num_inference_steps = num_inference_steps | 94 | self.num_inference_steps = num_inference_steps |
| 194 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 95 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
| 195 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 96 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) |
| 196 | self.timesteps = self.sigmas[:-1] | 97 | self.timesteps = self.sigmas[:-1] |
| 197 | self.is_scale_input_called = False | 98 | self.is_scale_input_called = False |
| 198 | 99 | ||
| @@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 251 | s_prev = self.sigmas[step_prev_index] | 152 | s_prev = self.sigmas[step_prev_index] |
| 252 | latents = sample | 153 | latents = sample |
| 253 | 154 | ||
| 254 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) | 155 | sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) |
| 255 | d = to_d(latents, s, model_output) | 156 | d = self.to_d(latents, s, model_output) |
| 256 | dt = sigma_down - s | 157 | dt = sigma_down - s |
| 257 | latents = latents + d * dt | 158 | latents = latents + d * dt |
| 258 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, | 159 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, |
| @@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 313 | noisy_samples = original_samples + noise * sigma | 214 | noisy_samples = original_samples + noise * sigma |
| 314 | self.is_scale_input_called = True | 215 | self.is_scale_input_called = True |
| 315 | return noisy_samples | 216 | return noisy_samples |
| 217 | |||
| 218 | # from k_samplers sampling.py | ||
| 219 | |||
| 220 | def get_ancestral_step(self, sigma_from, sigma_to): | ||
| 221 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
| 222 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
| 223 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
| 224 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
| 225 | return sigma_down, sigma_up | ||
| 226 | |||
| 227 | def t_to_sigma(self, t, sigmas): | ||
| 228 | t = t.float() | ||
| 229 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
| 230 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
| 231 | |||
| 232 | def append_zero(self, x): | ||
| 233 | return torch.cat([x, x.new_zeros([1])]) | ||
| 234 | |||
| 235 | def get_sigmas(self, sigmas, n=None): | ||
| 236 | if n is None: | ||
| 237 | return self.append_zero(sigmas.flip(0)) | ||
| 238 | t_max = len(sigmas) - 1 # = 999 | ||
| 239 | device = self.device | ||
| 240 | t = torch.linspace(t_max, 0, n, device=device) | ||
| 241 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 242 | return self.append_zero(self.t_to_sigma(t, sigmas)) | ||
| 243 | |||
| 244 | # from k_samplers utils.py | ||
| 245 | def append_dims(self, x, target_dims): | ||
| 246 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
| 247 | dims_to_append = target_dims - x.ndim | ||
| 248 | if dims_to_append < 0: | ||
| 249 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
| 250 | return x[(...,) + (None,) * dims_to_append] | ||
| 251 | |||
| 252 | # from k_samplers sampling.py | ||
| 253 | def to_d(self, x, sigma, denoised): | ||
| 254 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
| 255 | return (x - denoised) / self.append_dims(sigma, x.ndim) | ||
| 256 | |||
| 257 | def get_scalings(self, sigma): | ||
| 258 | sigma_data = 1. | ||
| 259 | c_out = -sigma | ||
| 260 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
| 261 | return c_out, c_in | ||
| 262 | |||
| 263 | # DiscreteSchedule DS | ||
| 264 | def DSsigma_to_t(self, sigma, quantize=None): | ||
| 265 | # quantize = self.quantize if quantize is None else quantize | ||
| 266 | quantize = False | ||
| 267 | dists = torch.abs(sigma - self.DSsigmas[:, None]) | ||
| 268 | if quantize: | ||
| 269 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
| 270 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
| 271 | low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] | ||
| 272 | w = (low - sigma) / (low - high) | ||
| 273 | w = w.clamp(0, 1) | ||
| 274 | t = (1 - w) * low_idx + w * high_idx | ||
| 275 | return t.view(sigma.shape) | ||
| 276 | |||
| 277 | def prepare_input(self, latent_in, t, batch_size): | ||
| 278 | sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 | ||
| 279 | |||
| 280 | sigma_in = torch.cat([sigma] * 2 * batch_size) | ||
| 281 | # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) | ||
| 282 | # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) | ||
| 283 | c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] | ||
| 284 | |||
| 285 | sigma_in = self.DSsigma_to_t(sigma_in) | ||
| 286 | # s_in = latent_in.new_ones([latent_in.shape[0]]) | ||
| 287 | # sigma_in = sigma_in * s_in | ||
| 288 | |||
| 289 | return c_out, c_in, sigma_in | ||
