summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py4
-rw-r--r--schedulers/scheduling_euler_ancestral_discrete.py9
2 files changed, 10 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py
index a181293..db097e5 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -193,7 +193,7 @@ def parse_args():
193 parser.add_argument( 193 parser.add_argument(
194 "--ema_power", 194 "--ema_power",
195 type=float, 195 type=float,
196 default=9 / 10 196 default=5 / 6
197 ) 197 )
198 parser.add_argument( 198 parser.add_argument(
199 "--ema_max_decay", 199 "--ema_max_decay",
@@ -957,7 +957,7 @@ def main():
957 "lr/text": lr_scheduler.get_last_lr()[1] 957 "lr/text": lr_scheduler.get_last_lr()[1]
958 } 958 }
959 if args.use_ema: 959 if args.use_ema:
960 logs["ema_decay"] = ema_unet.decay 960 logs["ema_decay"] = 1 - ema_unet.decay
961 961
962 accelerator.log(logs, step=global_step) 962 accelerator.log(logs, step=global_step)
963 963
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
index 3a2de68..828e0dd 100644
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -130,6 +130,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
130 timestep: Union[float, torch.FloatTensor], 130 timestep: Union[float, torch.FloatTensor],
131 step_index: Union[int, torch.IntTensor], 131 step_index: Union[int, torch.IntTensor],
132 sample: Union[torch.FloatTensor, np.ndarray], 132 sample: Union[torch.FloatTensor, np.ndarray],
133 generator: torch.Generator = None,
133 return_dict: bool = True, 134 return_dict: bool = True,
134 ) -> Union[SchedulerOutput, Tuple]: 135 ) -> Union[SchedulerOutput, Tuple]:
135 """ 136 """
@@ -165,7 +166,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
165 166
166 prev_sample = sample + derivative * dt 167 prev_sample = sample + derivative * dt
167 168
168 prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up 169 prev_sample = prev_sample + torch.randn(
170 prev_sample.shape,
171 layout=prev_sample.layout,
172 device=prev_sample.device,
173 dtype=prev_sample.dtype,
174 generator=generator
175 ) * sigma_up
169 176
170 if not return_dict: 177 if not return_dict:
171 return (prev_sample,) 178 return (prev_sample,)