summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-21 11:50:11 +0100
committerVolpeon <git@volpeon.ink>2023-02-21 11:50:11 +0100
commit9d6252e63bac241e5c6191eb47adb51b84a5d782 (patch)
tree6cb649510b48ca33419af3721e630f1c06bf1ae2 /train_ti.py
parentEmbedding normalization: Ignore tensors with grad = 0 (diff)
downloadtextual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.gz
textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.bz2
textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.zip
Don't rely on Accelerate for gradient accumulation
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 6dc07dd..68783ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -518,7 +518,6 @@ def main():
518 accelerator = Accelerator( 518 accelerator = Accelerator(
519 log_with=LoggerType.TENSORBOARD, 519 log_with=LoggerType.TENSORBOARD,
520 logging_dir=f"{output_dir}", 520 logging_dir=f"{output_dir}",
521 gradient_accumulation_steps=args.gradient_accumulation_steps,
522 mixed_precision=args.mixed_precision 521 mixed_precision=args.mixed_precision
523 ) 522 )
524 523
@@ -611,6 +610,7 @@ def main():
611 low_freq_noise=0, 610 low_freq_noise=0,
612 strategy=textual_inversion_strategy, 611 strategy=textual_inversion_strategy,
613 num_train_epochs=args.num_train_epochs, 612 num_train_epochs=args.num_train_epochs,
613 gradient_accumulation_steps=args.gradient_accumulation_steps,
614 sample_frequency=args.sample_frequency, 614 sample_frequency=args.sample_frequency,
615 checkpoint_frequency=args.checkpoint_frequency, 615 checkpoint_frequency=args.checkpoint_frequency,
616 milestone_checkpoints=not args.no_milestone_checkpoints, 616 milestone_checkpoints=not args.no_milestone_checkpoints,