summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
commit49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch)
tree8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /schedulers
parentInference: Add support for embeddings (diff)
downloadtextual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.gz
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.bz2
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.zip
Update
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py45
1 file changed, 38 insertions, 7 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index c6436d8..13ea6b3 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -171,6 +171,9 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
171 self.alphas = 1.0 - self.betas 171 self.alphas = 1.0 - self.betas
172 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 172 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
173 173
174 # standard deviation of the initial noise distribution
175 self.init_noise_sigma = 1.0
176
174 # setable values 177 # setable values
175 self.num_inference_steps = None 178 self.num_inference_steps = None
176 self.timesteps = np.arange(0, num_train_timesteps)[::-1] 179 self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -190,13 +193,33 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
190 self.num_inference_steps = num_inference_steps 193 self.num_inference_steps = num_inference_steps
191 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 194 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
192 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 195 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
193 self.timesteps = np.arange(0, self.num_inference_steps) 196 self.timesteps = self.sigmas[:-1]
197 self.is_scale_input_called = False
198
199 def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
200 """
201 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
202 current timestep.
203 Args:
204 sample (`torch.FloatTensor`): input sample
205 timestep (`int`, optional): current timestep
206 Returns:
207 `torch.FloatTensor`: scaled input sample
208 """
209 if isinstance(timestep, torch.Tensor):
210 timestep = timestep.to(self.timesteps.device)
211 if self.is_scale_input_called:
212 return sample
213 step_index = (self.timesteps == timestep).nonzero().item()
214 sigma = self.sigmas[step_index]
215 sample = sample * sigma
216 self.is_scale_input_called = True
217 return sample
194 218
195 def step( 219 def step(
196 self, 220 self,
197 model_output: torch.FloatTensor, 221 model_output: torch.FloatTensor,
198 timestep: int, 222 timestep: Union[float, torch.FloatTensor],
199 timestep_prev: int,
200 sample: torch.FloatTensor, 223 sample: torch.FloatTensor,
201 generator: torch.Generator = None, 224 generator: torch.Generator = None,
202 return_dict: bool = True, 225 return_dict: bool = True,
@@ -219,8 +242,13 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
219 returning a tuple, the first element is the sample tensor. 242 returning a tuple, the first element is the sample tensor.
220 243
221 """ 244 """
222 s = self.sigmas[timestep] 245 if isinstance(timestep, torch.Tensor):
223 s_prev = self.sigmas[timestep_prev] 246 timestep = timestep.to(self.timesteps.device)
247 step_index = (self.timesteps == timestep).nonzero().item()
248 step_prev_index = step_index + 1
249
250 s = self.sigmas[step_index]
251 s_prev = self.sigmas[step_prev_index]
224 latents = sample 252 latents = sample
225 253
226 sigma_down, sigma_up = get_ancestral_step(s, s_prev) 254 sigma_down, sigma_up = get_ancestral_step(s, s_prev)
@@ -271,14 +299,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
271 self, 299 self,
272 original_samples: torch.FloatTensor, 300 original_samples: torch.FloatTensor,
273 noise: torch.FloatTensor, 301 noise: torch.FloatTensor,
274 timesteps: torch.IntTensor, 302 timesteps: torch.FloatTensor,
275 ) -> torch.FloatTensor: 303 ) -> torch.FloatTensor:
276 sigmas = self.sigmas.to(original_samples.device) 304 sigmas = self.sigmas.to(original_samples.device)
305 schedule_timesteps = self.timesteps.to(original_samples.device)
277 timesteps = timesteps.to(original_samples.device) 306 timesteps = timesteps.to(original_samples.device)
307 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
278 308
279 sigma = sigmas[timesteps].flatten() 309 sigma = sigmas[step_indices].flatten()
280 while len(sigma.shape) < len(original_samples.shape): 310 while len(sigma.shape) < len(original_samples.shape):
281 sigma = sigma.unsqueeze(-1) 311 sigma = sigma.unsqueeze(-1)
282 312
283 noisy_samples = original_samples + noise * sigma 313 noisy_samples = original_samples + noise * sigma
314 self.is_scale_input_called = True
284 return noisy_samples 315 return noisy_samples