diff options
author | Volpeon <git@volpeon.ink> | 2022-10-06 17:15:22 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-06 17:15:22 +0200 |
commit | 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch) | |
tree | 8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /schedulers | |
parent | Inference: Add support for embeddings (diff) | |
download | textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.gz textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.bz2 textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.zip |
Update
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 45 |
1 files changed, 38 insertions, 7 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index c6436d8..13ea6b3 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -171,6 +171,9 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
171 | self.alphas = 1.0 - self.betas | 171 | self.alphas = 1.0 - self.betas |
172 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 172 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
173 | 173 | ||
174 | # standard deviation of the initial noise distribution | ||
175 | self.init_noise_sigma = 1.0 | ||
176 | |||
174 | # setable values | 177 | # setable values |
175 | self.num_inference_steps = None | 178 | self.num_inference_steps = None |
176 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | 179 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] |
@@ -190,13 +193,33 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
190 | self.num_inference_steps = num_inference_steps | 193 | self.num_inference_steps = num_inference_steps |
191 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 194 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
192 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 195 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
193 | self.timesteps = np.arange(0, self.num_inference_steps) | 196 | self.timesteps = self.sigmas[:-1] |
197 | self.is_scale_input_called = False | ||
198 | |||
199 | def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: | ||
200 | """ | ||
201 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the | ||
202 | current timestep. | ||
203 | Args: | ||
204 | sample (`torch.FloatTensor`): input sample | ||
205 | timestep (`int`, optional): current timestep | ||
206 | Returns: | ||
207 | `torch.FloatTensor`: scaled input sample | ||
208 | """ | ||
209 | if isinstance(timestep, torch.Tensor): | ||
210 | timestep = timestep.to(self.timesteps.device) | ||
211 | if self.is_scale_input_called: | ||
212 | return sample | ||
213 | step_index = (self.timesteps == timestep).nonzero().item() | ||
214 | sigma = self.sigmas[step_index] | ||
215 | sample = sample * sigma | ||
216 | self.is_scale_input_called = True | ||
217 | return sample | ||
194 | 218 | ||
195 | def step( | 219 | def step( |
196 | self, | 220 | self, |
197 | model_output: torch.FloatTensor, | 221 | model_output: torch.FloatTensor, |
198 | timestep: int, | 222 | timestep: Union[float, torch.FloatTensor], |
199 | timestep_prev: int, | ||
200 | sample: torch.FloatTensor, | 223 | sample: torch.FloatTensor, |
201 | generator: torch.Generator = None, | 224 | generator: torch.Generator = None, |
202 | return_dict: bool = True, | 225 | return_dict: bool = True, |
@@ -219,8 +242,13 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
219 | returning a tuple, the first element is the sample tensor. | 242 | returning a tuple, the first element is the sample tensor. |
220 | 243 | ||
221 | """ | 244 | """ |
222 | s = self.sigmas[timestep] | 245 | if isinstance(timestep, torch.Tensor): |
223 | s_prev = self.sigmas[timestep_prev] | 246 | timestep = timestep.to(self.timesteps.device) |
247 | step_index = (self.timesteps == timestep).nonzero().item() | ||
248 | step_prev_index = step_index + 1 | ||
249 | |||
250 | s = self.sigmas[step_index] | ||
251 | s_prev = self.sigmas[step_prev_index] | ||
224 | latents = sample | 252 | latents = sample |
225 | 253 | ||
226 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) | 254 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) |
@@ -271,14 +299,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
271 | self, | 299 | self, |
272 | original_samples: torch.FloatTensor, | 300 | original_samples: torch.FloatTensor, |
273 | noise: torch.FloatTensor, | 301 | noise: torch.FloatTensor, |
274 | timesteps: torch.IntTensor, | 302 | timesteps: torch.FloatTensor, |
275 | ) -> torch.FloatTensor: | 303 | ) -> torch.FloatTensor: |
276 | sigmas = self.sigmas.to(original_samples.device) | 304 | sigmas = self.sigmas.to(original_samples.device) |
305 | schedule_timesteps = self.timesteps.to(original_samples.device) | ||
277 | timesteps = timesteps.to(original_samples.device) | 306 | timesteps = timesteps.to(original_samples.device) |
307 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
278 | 308 | ||
279 | sigma = sigmas[timesteps].flatten() | 309 | sigma = sigmas[step_indices].flatten() |
280 | while len(sigma.shape) < len(original_samples.shape): | 310 | while len(sigma.shape) < len(original_samples.shape): |
281 | sigma = sigma.unsqueeze(-1) | 311 | sigma = sigma.unsqueeze(-1) |
282 | 312 | ||
283 | noisy_samples = original_samples + noise * sigma | 313 | noisy_samples = original_samples + noise * sigma |
314 | self.is_scale_input_called = True | ||
284 | return noisy_samples | 315 | return noisy_samples |