diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 15 |
1 file changed, 10 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 775aea2..699313e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
129 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--learning_rate", | 130 | "--learning_rate", |
131 | type=float, | 131 | type=float, |
132 | default=1e-6, | 132 | default=5e-6, |
133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
@@ -167,7 +167,7 @@ def parse_args():
167 | parser.add_argument( | 167 | parser.add_argument( |
168 | "--ema_power", | 168 | "--ema_power", |
169 | type=float, | 169 | type=float, |
170 | default=7 / 8 | 170 | default=6 / 7 |
171 | ) | 171 | ) |
172 | parser.add_argument( | 172 | parser.add_argument( |
173 | "--ema_max_decay", | 173 | "--ema_max_decay", |
@@ -270,6 +270,11 @@ def parse_args():
270 | help="Max gradient norm." | 270 | help="Max gradient norm." |
271 | ) | 271 | ) |
272 | parser.add_argument( | 272 | parser.add_argument( |
273 | "--noise_timesteps", | ||
274 | type=int, | ||
275 | default=1000, | ||
276 | ) | ||
277 | parser.add_argument( | ||
273 | "--config", | 278 | "--config", |
274 | type=str, | 279 | type=str, |
275 | default=None, | 280 | default=None, |
@@ -480,7 +485,8 @@ def main():
480 | unet, | 485 | unet, |
481 | inv_gamma=args.ema_inv_gamma, | 486 | inv_gamma=args.ema_inv_gamma, |
482 | power=args.ema_power, | 487 | power=args.ema_power, |
483 | max_value=args.ema_max_decay | 488 | max_value=args.ema_max_decay, |
489 | device=accelerator.device | ||
484 | ) if args.use_ema else None | 490 | ) if args.use_ema else None |
485 | 491 | ||
486 | if args.gradient_checkpointing: | 492 | if args.gradient_checkpointing: |
@@ -523,7 +529,7 @@ def main():
523 | beta_start=0.00085, | 529 | beta_start=0.00085, |
524 | beta_end=0.012, | 530 | beta_end=0.012, |
525 | beta_schedule="scaled_linear", | 531 | beta_schedule="scaled_linear", |
526 | num_train_timesteps=1000 | 532 | num_train_timesteps=args.noise_timesteps |
527 | ) | 533 | ) |
528 | 534 | ||
529 | def collate_fn(examples): | 535 | def collate_fn(examples): |
@@ -632,7 +638,6 @@ def main():
632 | # Move text_encoder and vae to device | 638 | # Move text_encoder and vae to device |
633 | text_encoder.to(accelerator.device) | 639 | text_encoder.to(accelerator.device) |
634 | vae.to(accelerator.device) | 640 | vae.to(accelerator.device) |
635 | ema_unet.averaged_model.to(accelerator.device) | ||
636 | 641 | ||
637 | # Keep text_encoder and vae in eval mode as we don't train these | 642 | # Keep text_encoder and vae in eval mode as we don't train these |
638 | text_encoder.eval() | 643 | text_encoder.eval() |