diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 3 |
1 files changed, 0 insertions, 3 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 6abe971..c097a8a 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 47 | beta_end: float = 0.02, | 47 | beta_end: float = 0.02, |
| 48 | beta_schedule: str = "linear", | 48 | beta_schedule: str = "linear", |
| 49 | trained_betas: Optional[np.ndarray] = None, | 49 | trained_betas: Optional[np.ndarray] = None, |
| 50 | tensor_format: str = "pt", | ||
| 51 | num_inference_steps=None, | 50 | num_inference_steps=None, |
| 52 | device='cuda' | 51 | device='cuda' |
| 53 | ): | 52 | ): |
| @@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 63 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | 62 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") |
| 64 | 63 | ||
| 65 | self.device = device | 64 | self.device = device |
| 66 | self.tensor_format = tensor_format | ||
| 67 | 65 | ||
| 68 | self.alphas = 1.0 - self.betas | 66 | self.alphas = 1.0 - self.betas |
| 69 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 67 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| @@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 77 | # get sigmas | 75 | # get sigmas |
| 78 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 76 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
| 79 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) | 77 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) |
| 80 | self.set_format(tensor_format=tensor_format) | ||
| 81 | 78 | ||
| 82 | # A# take number of steps as input | 79 | # A# take number of steps as input |
| 83 | # A# store 1) number of steps 2) timesteps 3) schedule | 80 | # A# store 1) number of steps 2) timesteps 3) schedule |
