From bd903c95053d2b5c4a85475e9ce3a2037b40d92a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Oct 2022 10:30:43 +0200 Subject: Euler_a: Re-introduce generator arg for reproducible output --- dreambooth.py | 4 ++-- schedulers/scheduling_euler_ancestral_discrete.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index a181293..db097e5 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -193,7 +193,7 @@ def parse_args(): parser.add_argument( "--ema_power", type=float, - default=9 / 10 + default=5 / 6 ) parser.add_argument( "--ema_max_decay", @@ -957,7 +957,7 @@ def main(): "lr/text": lr_scheduler.get_last_lr()[1] } if args.use_ema: - logs["ema_decay"] = ema_unet.decay + logs["ema_decay"] = 1 - ema_unet.decay accelerator.log(logs, step=global_step) diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py index 3a2de68..828e0dd 100644 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ b/schedulers/scheduling_euler_ancestral_discrete.py @@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], + generator: torch.Generator = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prev_sample = sample + derivative * dt - prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up + prev_sample = prev_sample + torch.randn( + prev_sample.shape, + layout=prev_sample.layout, + device=prev_sample.device, + dtype=prev_sample.dtype, + generator=generator + ) * sigma_up if not return_dict: return (prev_sample,) -- cgit v1.2.3-54-g00ecf