diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-21 11:50:11 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 11:50:11 +0100 |
| commit | 9d6252e63bac241e5c6191eb47adb51b84a5d782 (patch) | |
| tree | 6cb649510b48ca33419af3721e630f1c06bf1ae2 /train_ti.py | |
| parent | Embedding normalization: Ignore tensors with grad = 0 (diff) | |
| download | textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.gz textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.bz2 textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.zip | |
Don't rely on Accelerate for gradient accumulation
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 6dc07dd..68783ea 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -518,7 +518,6 @@ def main(): | |||
| 518 | accelerator = Accelerator( | 518 | accelerator = Accelerator( |
| 519 | log_with=LoggerType.TENSORBOARD, | 519 | log_with=LoggerType.TENSORBOARD, |
| 520 | logging_dir=f"{output_dir}", | 520 | logging_dir=f"{output_dir}", |
| 521 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 522 | mixed_precision=args.mixed_precision | 521 | mixed_precision=args.mixed_precision |
| 523 | ) | 522 | ) |
| 524 | 523 | ||
| @@ -611,6 +610,7 @@ def main(): | |||
| 611 | low_freq_noise=0, | 610 | low_freq_noise=0, |
| 612 | strategy=textual_inversion_strategy, | 611 | strategy=textual_inversion_strategy, |
| 613 | num_train_epochs=args.num_train_epochs, | 612 | num_train_epochs=args.num_train_epochs, |
| 613 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 614 | sample_frequency=args.sample_frequency, | 614 | sample_frequency=args.sample_frequency, |
| 615 | checkpoint_frequency=args.checkpoint_frequency, | 615 | checkpoint_frequency=args.checkpoint_frequency, |
| 616 | milestone_checkpoints=not args.no_milestone_checkpoints, | 616 | milestone_checkpoints=not args.no_milestone_checkpoints, |
