diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 85dd884..739d055 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -362,6 +362,7 @@ def train_loop( | |||
362 | loss_step: LossCallable, | 362 | loss_step: LossCallable, |
363 | sample_frequency: int = 10, | 363 | sample_frequency: int = 10, |
364 | checkpoint_frequency: int = 50, | 364 | checkpoint_frequency: int = 50, |
365 | milestone_checkpoints: bool = True, | ||
365 | global_step_offset: int = 0, | 366 | global_step_offset: int = 0, |
366 | num_epochs: int = 100, | 367 | num_epochs: int = 100, |
367 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 368 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
@@ -514,7 +515,7 @@ def train_loop( | |||
514 | accelerator.log(logs, step=global_step) | 515 | accelerator.log(logs, step=global_step) |
515 | 516 | ||
516 | if accelerator.is_main_process: | 517 | if accelerator.is_main_process: |
517 | if avg_acc_val.avg.item() > best_acc_val: | 518 | if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: |
518 | local_progress_bar.clear() | 519 | local_progress_bar.clear() |
519 | global_progress_bar.clear() | 520 | global_progress_bar.clear() |
520 | 521 | ||
@@ -527,7 +528,7 @@ def train_loop( | |||
527 | accs.append(avg_acc_val.avg.item()) | 528 | accs.append(avg_acc_val.avg.item()) |
528 | else: | 529 | else: |
529 | if accelerator.is_main_process: | 530 | if accelerator.is_main_process: |
530 | if avg_acc.avg.item() > best_acc: | 531 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
531 | local_progress_bar.clear() | 532 | local_progress_bar.clear() |
532 | global_progress_bar.clear() | 533 | global_progress_bar.clear() |
533 | 534 | ||
@@ -572,6 +573,7 @@ def train( | |||
572 | num_train_epochs: int = 100, | 573 | num_train_epochs: int = 100, |
573 | sample_frequency: int = 20, | 574 | sample_frequency: int = 20, |
574 | checkpoint_frequency: int = 50, | 575 | checkpoint_frequency: int = 50, |
576 | milestone_checkpoints: bool = True, | ||
575 | global_step_offset: int = 0, | 577 | global_step_offset: int = 0, |
576 | with_prior_preservation: bool = False, | 578 | with_prior_preservation: bool = False, |
577 | prior_loss_weight: float = 1.0, | 579 | prior_loss_weight: float = 1.0, |
@@ -626,6 +628,7 @@ def train( | |||
626 | loss_step=loss_step_, | 628 | loss_step=loss_step_, |
627 | sample_frequency=sample_frequency, | 629 | sample_frequency=sample_frequency, |
628 | checkpoint_frequency=checkpoint_frequency, | 630 | checkpoint_frequency=checkpoint_frequency, |
631 | milestone_checkpoints=milestone_checkpoints, | ||
629 | global_step_offset=global_step_offset, | 632 | global_step_offset=global_step_offset, |
630 | num_epochs=num_train_epochs, | 633 | num_epochs=num_train_epochs, |
631 | callbacks=callbacks, | 634 | callbacks=callbacks, |