diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-27 10:30:43 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-27 10:30:43 +0200 |
| commit | bd903c95053d2b5c4a85475e9ce3a2037b40d92a (patch) | |
| tree | de0d52394ea62a69073432dbad12cbf999a64de3 | |
| parent | New Euler_a scheduler (diff) | |
| download | textual-inversion-diff-bd903c95053d2b5c4a85475e9ce3a2037b40d92a.tar.gz textual-inversion-diff-bd903c95053d2b5c4a85475e9ce3a2037b40d92a.tar.bz2 textual-inversion-diff-bd903c95053d2b5c4a85475e9ce3a2037b40d92a.zip | |
Euler_a: Re-introduce generator arg for reproducible output
| -rw-r--r-- | dreambooth.py | 4 | ||||
| -rw-r--r-- | schedulers/scheduling_euler_ancestral_discrete.py | 9 |
2 files changed, 10 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index a181293..db097e5 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -193,7 +193,7 @@ def parse_args(): | |||
| 193 | parser.add_argument( | 193 | parser.add_argument( |
| 194 | "--ema_power", | 194 | "--ema_power", |
| 195 | type=float, | 195 | type=float, |
| 196 | default=9 / 10 | 196 | default=5 / 6 |
| 197 | ) | 197 | ) |
| 198 | parser.add_argument( | 198 | parser.add_argument( |
| 199 | "--ema_max_decay", | 199 | "--ema_max_decay", |
| @@ -957,7 +957,7 @@ def main(): | |||
| 957 | "lr/text": lr_scheduler.get_last_lr()[1] | 957 | "lr/text": lr_scheduler.get_last_lr()[1] |
| 958 | } | 958 | } |
| 959 | if args.use_ema: | 959 | if args.use_ema: |
| 960 | logs["ema_decay"] = ema_unet.decay | 960 | logs["ema_decay"] = 1 - ema_unet.decay |
| 961 | 961 | ||
| 962 | accelerator.log(logs, step=global_step) | 962 | accelerator.log(logs, step=global_step) |
| 963 | 963 | ||
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py index 3a2de68..828e0dd 100644 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ b/schedulers/scheduling_euler_ancestral_discrete.py | |||
| @@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
| 130 | timestep: Union[float, torch.FloatTensor], | 130 | timestep: Union[float, torch.FloatTensor], |
| 131 | step_index: Union[int, torch.IntTensor], | 131 | step_index: Union[int, torch.IntTensor], |
| 132 | sample: Union[torch.FloatTensor, np.ndarray], | 132 | sample: Union[torch.FloatTensor, np.ndarray], |
| 133 | generator: torch.Generator = None, | ||
| 133 | return_dict: bool = True, | 134 | return_dict: bool = True, |
| 134 | ) -> Union[SchedulerOutput, Tuple]: | 135 | ) -> Union[SchedulerOutput, Tuple]: |
| 135 | """ | 136 | """ |
| @@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | |||
| 165 | 166 | ||
| 166 | prev_sample = sample + derivative * dt | 167 | prev_sample = sample + derivative * dt |
| 167 | 168 | ||
| 168 | prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up | 169 | prev_sample = prev_sample + torch.randn( |
| 170 | prev_sample.shape, | ||
| 171 | layout=prev_sample.layout, | ||
| 172 | device=prev_sample.device, | ||
| 173 | dtype=prev_sample.dtype, | ||
| 174 | generator=generator | ||
| 175 | ) * sigma_up | ||
| 169 | 176 | ||
| 170 | if not return_dict: | 177 | if not return_dict: |
| 171 | return (prev_sample,) | 178 | return (prev_sample,) |
