summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 775aea2..699313e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
129 parser.add_argument( 129 parser.add_argument(
130 "--learning_rate", 130 "--learning_rate",
131 type=float, 131 type=float,
132 default=1e-6, 132 default=5e-6,
133 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
167 parser.add_argument( 167 parser.add_argument(
168 "--ema_power", 168 "--ema_power",
169 type=float, 169 type=float,
170 default=7 / 8 170 default=6 / 7
171 ) 171 )
172 parser.add_argument( 172 parser.add_argument(
173 "--ema_max_decay", 173 "--ema_max_decay",
@@ -270,6 +270,11 @@ def parse_args():
270 help="Max gradient norm." 270 help="Max gradient norm."
271 ) 271 )
272 parser.add_argument( 272 parser.add_argument(
273 "--noise_timesteps",
274 type=int,
275 default=1000,
276 )
277 parser.add_argument(
273 "--config", 278 "--config",
274 type=str, 279 type=str,
275 default=None, 280 default=None,
@@ -480,7 +485,8 @@ def main():
480 unet, 485 unet,
481 inv_gamma=args.ema_inv_gamma, 486 inv_gamma=args.ema_inv_gamma,
482 power=args.ema_power, 487 power=args.ema_power,
483 max_value=args.ema_max_decay 488 max_value=args.ema_max_decay,
489 device=accelerator.device
484 ) if args.use_ema else None 490 ) if args.use_ema else None
485 491
486 if args.gradient_checkpointing: 492 if args.gradient_checkpointing:
@@ -523,7 +529,7 @@ def main():
523 beta_start=0.00085, 529 beta_start=0.00085,
524 beta_end=0.012, 530 beta_end=0.012,
525 beta_schedule="scaled_linear", 531 beta_schedule="scaled_linear",
526 num_train_timesteps=1000 532 num_train_timesteps=args.noise_timesteps
527 ) 533 )
528 534
529 def collate_fn(examples): 535 def collate_fn(examples):
@@ -632,7 +638,6 @@ def main():
632 # Move text_encoder and vae to device 638 # Move text_encoder and vae to device
633 text_encoder.to(accelerator.device) 639 text_encoder.to(accelerator.device)
634 vae.to(accelerator.device) 640 vae.to(accelerator.device)
635 ema_unet.averaged_model.to(accelerator.device)
636 641
637 # Keep text_encoder and vae in eval mode as we don't train these 642 # Keep text_encoder and vae in eval mode as we don't train these
638 text_encoder.eval() 643 text_encoder.eval()