summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
index 3a2de68..828e0dd 100644
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
130 timestep: Union[float, torch.FloatTensor], 130 timestep: Union[float, torch.FloatTensor],
131 step_index: Union[int, torch.IntTensor], 131 step_index: Union[int, torch.IntTensor],
132 sample: Union[torch.FloatTensor, np.ndarray], 132 sample: Union[torch.FloatTensor, np.ndarray],
133 generator: torch.Generator = None,
133 return_dict: bool = True, 134 return_dict: bool = True,
134 ) -> Union[SchedulerOutput, Tuple]: 135 ) -> Union[SchedulerOutput, Tuple]:
135 """ 136 """
@@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
165 166
166 prev_sample = sample + derivative * dt 167 prev_sample = sample + derivative * dt
167 168
168 prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up 169 prev_sample = prev_sample + torch.randn(
170 prev_sample.shape,
171 layout=prev_sample.layout,
172 device=prev_sample.device,
173 dtype=prev_sample.dtype,
174 generator=generator
175 ) * sigma_up
169 176
170 if not return_dict: 177 if not return_dict:
171 return (prev_sample,) 178 return (prev_sample,)