diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 431ff3d..280cf77 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -439,7 +439,6 @@ def main(): | |||
439 | accelerator = Accelerator( | 439 | accelerator = Accelerator( |
440 | log_with=LoggerType.TENSORBOARD, | 440 | log_with=LoggerType.TENSORBOARD, |
441 | logging_dir=f"{output_dir}", | 441 | logging_dir=f"{output_dir}", |
442 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
443 | mixed_precision=args.mixed_precision | 442 | mixed_precision=args.mixed_precision |
444 | ) | 443 | ) |
445 | 444 | ||
@@ -590,6 +589,7 @@ def main(): | |||
590 | lr_scheduler=lr_scheduler, | 589 | lr_scheduler=lr_scheduler, |
591 | prepare_unet=True, | 590 | prepare_unet=True, |
592 | num_train_epochs=args.num_train_epochs, | 591 | num_train_epochs=args.num_train_epochs, |
592 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
593 | sample_frequency=args.sample_frequency, | 593 | sample_frequency=args.sample_frequency, |
594 | # -- | 594 | # -- |
595 | tokenizer=tokenizer, | 595 | tokenizer=tokenizer, |