diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_ancestral_discrete.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py index 3a2de68..828e0dd 100644 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ b/schedulers/scheduling_euler_ancestral_discrete.py | |||
@@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
130 | timestep: Union[float, torch.FloatTensor], | 130 | timestep: Union[float, torch.FloatTensor], |
131 | step_index: Union[int, torch.IntTensor], | 131 | step_index: Union[int, torch.IntTensor], |
132 | sample: Union[torch.FloatTensor, np.ndarray], | 132 | sample: Union[torch.FloatTensor, np.ndarray], |
133 | generator: torch.Generator = None, | ||
133 | return_dict: bool = True, | 134 | return_dict: bool = True, |
134 | ) -> Union[SchedulerOutput, Tuple]: | 135 | ) -> Union[SchedulerOutput, Tuple]: |
135 | """ | 136 | """ |
@@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
165 | 166 | ||
166 | prev_sample = sample + derivative * dt | 167 | prev_sample = sample + derivative * dt |
167 | 168 | ||
168 | prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up | 169 | prev_sample = prev_sample + torch.randn( |
170 | prev_sample.shape, | ||
171 | layout=prev_sample.layout, | ||
172 | device=prev_sample.device, | ||
173 | dtype=prev_sample.dtype, | ||
174 | generator=generator | ||
175 | ) * sigma_up | ||
169 | 176 | ||
170 | if not return_dict: | 177 | if not return_dict: |
171 | return (prev_sample,) | 178 | return (prev_sample,) |