diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_ancestral_discrete.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py index 3a2de68..828e0dd 100644 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ b/schedulers/scheduling_euler_ancestral_discrete.py | |||
| @@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
| 130 | timestep: Union[float, torch.FloatTensor], | 130 | timestep: Union[float, torch.FloatTensor], |
| 131 | step_index: Union[int, torch.IntTensor], | 131 | step_index: Union[int, torch.IntTensor], |
| 132 | sample: Union[torch.FloatTensor, np.ndarray], | 132 | sample: Union[torch.FloatTensor, np.ndarray], |
| 133 | generator: torch.Generator = None, | ||
| 133 | return_dict: bool = True, | 134 | return_dict: bool = True, |
| 134 | ) -> Union[SchedulerOutput, Tuple]: | 135 | ) -> Union[SchedulerOutput, Tuple]: |
| 135 | """ | 136 | """ |
| @@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
| 165 | 166 | ||
| 166 | prev_sample = sample + derivative * dt | 167 | prev_sample = sample + derivative * dt |
| 167 | 168 | ||
| 168 | prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up | 169 | prev_sample = prev_sample + torch.randn( |
| 170 | prev_sample.shape, | ||
| 171 | layout=prev_sample.layout, | ||
| 172 | device=prev_sample.device, | ||
| 173 | dtype=prev_sample.dtype, | ||
| 174 | generator=generator | ||
| 175 | ) * sigma_up | ||
| 169 | 176 | ||
| 170 | if not return_dict: | 177 | if not return_dict: |
| 171 | return (prev_sample,) | 178 | return (prev_sample,) |
