diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 3 |
1 files changed, 0 insertions, 3 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 6abe971..c097a8a 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
47 | beta_end: float = 0.02, | 47 | beta_end: float = 0.02, |
48 | beta_schedule: str = "linear", | 48 | beta_schedule: str = "linear", |
49 | trained_betas: Optional[np.ndarray] = None, | 49 | trained_betas: Optional[np.ndarray] = None, |
50 | tensor_format: str = "pt", | ||
51 | num_inference_steps=None, | 50 | num_inference_steps=None, |
52 | device='cuda' | 51 | device='cuda' |
53 | ): | 52 | ): |
@@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
63 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | 62 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") |
64 | 63 | ||
65 | self.device = device | 64 | self.device = device |
66 | self.tensor_format = tensor_format | ||
67 | 65 | ||
68 | self.alphas = 1.0 - self.betas | 66 | self.alphas = 1.0 - self.betas |
69 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 67 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
@@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
77 | # get sigmas | 75 | # get sigmas |
78 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 76 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
79 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) | 77 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) |
80 | self.set_format(tensor_format=tensor_format) | ||
81 | 78 | ||
82 | # A# take number of steps as input | 79 | # A# take number of steps as input |
83 | # A# store 1) number of steps 2) timesteps 3) schedule | 80 | # A# store 1) number of steps 2) timesteps 3) schedule |