summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 83043ad..dd015f9 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -774,7 +774,9 @@ def main():
774 num_train_epochs = args.num_train_epochs 774 num_train_epochs = args.num_train_epochs
775 sample_frequency = args.sample_frequency 775 sample_frequency = args.sample_frequency
776 if num_train_epochs is None: 776 if num_train_epochs is None:
777 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 777 num_train_epochs = math.ceil(
778 args.num_train_steps / len(datamodule.train_dataset)
779 ) * args.gradient_accumulation_steps
778 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 780 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
779 781
780 optimizer = create_optimizer( 782 optimizer = create_optimizer(