path: root/schedulers
author    Volpeon <git@volpeon.ink>  2022-10-02 15:14:29 +0200
committer Volpeon <git@volpeon.ink>  2022-10-02 15:14:29 +0200
commit    13b0d9f763269df405d1aeba86213f1c7ce4e7ca (patch)
tree      b4b2761032e2ba715dac0cf50adee9ff911d73f6 /schedulers
parent    WIP: img2img (diff)
download  textual-inversion-diff-13b0d9f763269df405d1aeba86213f1c7ce4e7ca.tar.gz
          textual-inversion-diff-13b0d9f763269df405d1aeba86213f1c7ce4e7ca.tar.bz2
          textual-inversion-diff-13b0d9f763269df405d1aeba86213f1c7ce4e7ca.zip
More consistent euler_a
Diffstat (limited to 'schedulers')
-rw-r--r--  schedulers/scheduling_euler_a.py  59
1 file changed, 26 insertions(+), 33 deletions(-)
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 9fbedaa..1b1c9cf 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -1,7 +1,3 @@
-
-
-import math
-import warnings
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        clip_sample: bool = True,
-        set_alpha_to_one: bool = True,
-        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
-        # At every step in ddim, we are looking into the previous alphas_cumprod
-        # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this parameter simply to one or
-        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
         # setable values
         self.num_inference_steps = None
         self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             the number of diffusion steps used when generating samples with a pre-trained model.
         """
 
-        # offset = self.config.steps_offset
-
-        # if "offset" in kwargs:
-        #     warnings.warn(
-        #         "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
-        #         " Please pass `steps_offset` to `__init__` instead.",
-        #         DeprecationWarning,
-        #     )
-
-        # offset = kwargs["offset"]
-
         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
-        self.timesteps = self.sigmas
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.timesteps = np.arange(0, self.num_inference_steps)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
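The new code stops treating the sigma values themselves as timesteps: set_timesteps now asks get_sigmas for the full num_inference_steps and exposes timesteps as plain integer indices into self.sigmas. get_sigmas is defined elsewhere in this file and is not shown in this hunk; as a rough sketch, assuming it follows k-diffusion's DiscreteSchedule.get_sigmas (the code this scheduler is ported from), it interpolates the trained sigma table down to n values and appends a terminal zero:

import torch

def append_zero(x):
    # Terminal sigma of 0 so the final step lands on the clean sample.
    return torch.cat([x, x.new_zeros([1])])

def get_sigmas(ds_sigmas, n):
    # Evenly spaced (possibly fractional) positions over the trained sigma
    # table, from the largest sigma down to the smallest.
    t_max = len(ds_sigmas) - 1
    t = torch.linspace(t_max, 0, n, device=ds_sigmas.device)
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    # Linear interpolation between neighbouring trained sigmas.
    sigmas = (1 - w) * ds_sigmas[low_idx] + w * ds_sigmas[high_idx]
    return append_zero(sigmas)

With n + 1 sigmas ending in 0, every timestep i in np.arange(0, n) has a valid successor index i + 1, which is exactly how step consumes them below.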
@@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: torch.IntTensor,
-        timestep_prev: torch.IntTensor,
+        timestep: int,
+        timestep_prev: int,
         sample: torch.FloatTensor,
         generator: None,
         # ,sigma_hat: float,
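Since timestep and timestep_prev are now plain int indices rather than tensors, a caller walks consecutive index pairs. A hypothetical driver loop, not taken from this repository (model stands in for whatever produces the denoised-latent prediction that step expects as model_output):

def sample(scheduler, model, latents, num_inference_steps, generator=None):
    scheduler.set_timesteps(num_inference_steps, device=latents.device)
    for i in scheduler.timesteps:  # 0, 1, ..., num_inference_steps - 1
        # `model` must return the denoised latents at sigmas[i]; `step`
        # turns that into an ancestral Euler update towards sigmas[i + 1].
        denoised = model(latents, scheduler.sigmas[i])
        out = scheduler.step(denoised, int(i), int(i) + 1, latents, generator=generator)
        latents = out.prev_sample
    return latents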
@@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
 
         """
+        s = self.sigmas[timestep]
+        s_prev = self.sigmas[timestep_prev]
         latents = sample
-        sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-        d = to_d(latents, timestep, model_output)
-        dt = sigma_down - timestep
+
+        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
+        d = to_d(latents, s, model_output)
+        dt = sigma_down - s
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
                                         generator=generator) * sigma_up
+
         return SchedulerOutput(prev_sample=latents)
 
     def step_correct(
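get_ancestral_step and to_d are now fed sigma values looked up from the integer indices, which is the form both helpers expect. They are defined elsewhere in this file; for reference, the k-diffusion originals this port follows are essentially the following (the scalar division in to_d is a simplification of k-diffusion's append_dims broadcasting):

def get_ancestral_step(sigma_from, sigma_to):
    # Split the move from sigma_from down to sigma_to into a deterministic
    # step to sigma_down plus fresh noise of standard deviation sigma_up,
    # preserving the overall variance of the transition.
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def to_d(x, sigma, denoised):
    # ODE derivative dx/dsigma implied by the current denoised prediction.
    return (x - denoised) / sigma

The step body is then exactly one Euler step from s down to sigma_down, followed by adding noise scaled by sigma_up, matching the structure of k-diffusion's sample_euler_ancestral.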
@@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=sample_prev)
 
-    def add_noise(self, original_samples, noise, timesteps):
-        raise NotImplementedError()
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        sigmas = self.sigmas.to(original_samples.device)
+        timesteps = timesteps.to(original_samples.device)
+
+        sigma = sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
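add_noise replaces the old NotImplementedError stub, which the img2img work in the parent commit needs: in the sigma parameterisation a clean latent is noised by simply adding noise * sigma, with the looked-up sigma broadcast up to the sample's shape. A hypothetical usage sketch (strength, init_latents, and the start-step arithmetic are illustrative assumptions, not taken from this diff):

import torch

def prepare_img2img_latents(scheduler, init_latents, strength, num_inference_steps, generator=None):
    scheduler.set_timesteps(num_inference_steps, device=init_latents.device)
    # Skip the first part of the schedule, as an img2img `strength` would:
    # strength 1.0 starts at the largest sigma, smaller values start later.
    t_start = int(num_inference_steps * (1.0 - strength))
    noise = torch.randn(init_latents.shape, generator=generator, device=init_latents.device)
    timesteps = torch.tensor([t_start], device=init_latents.device)
    # add_noise broadcasts sigmas[t_start] across the latent dimensions.
    return scheduler.add_noise(init_latents, noise, timesteps), t_start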