From bd903c95053d2b5c4a85475e9ce3a2037b40d92a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Oct 2022 10:30:43 +0200 Subject: Euler_a: Re-introduce generator arg for reproducible output --- schedulers/scheduling_euler_ancestral_discrete.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py index 3a2de68..828e0dd 100644 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ b/schedulers/scheduling_euler_ancestral_discrete.py @@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], + generator: torch.Generator = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prev_sample = sample + derivative * dt - prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up + prev_sample = prev_sample + torch.randn( + prev_sample.shape, + layout=prev_sample.layout, + device=prev_sample.device, + dtype=prev_sample.dtype, + generator=generator + ) * sigma_up if not return_dict: return (prev_sample,) -- cgit v1.2.3-54-g00ecf