diff options

-rw-r--r--   dreambooth.py                                        4
-rw-r--r--   schedulers/scheduling_euler_ancestral_discrete.py    9

2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index a181293..db097e5 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -193,7 +193,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=9 / 10
+        default=5 / 6
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -957,7 +957,7 @@ def main():
                     "lr/text": lr_scheduler.get_last_lr()[1]
                 }
                 if args.use_ema:
-                    logs["ema_decay"] = ema_unet.decay
+                    logs["ema_decay"] = 1 - ema_unet.decay

                 accelerator.log(logs, step=global_step)
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
index 3a2de68..828e0dd 100644
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.FloatTensor],
         step_index: Union[int, torch.IntTensor],
         sample: Union[torch.FloatTensor, np.ndarray],
+        generator: torch.Generator = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

         prev_sample = sample + derivative * dt

-        prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up
+        prev_sample = prev_sample + torch.randn(
+            prev_sample.shape,
+            layout=prev_sample.layout,
+            device=prev_sample.device,
+            dtype=prev_sample.dtype,
+            generator=generator
+        ) * sigma_up

         if not return_dict:
             return (prev_sample,)