diff options
| -rw-r--r-- | training/functional.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index b7ea90d..78a2b10 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -353,7 +353,6 @@ def train_loop( | |||
| 353 | train_dataloader: DataLoader, | 353 | train_dataloader: DataLoader, |
| 354 | val_dataloader: Optional[DataLoader], | 354 | val_dataloader: Optional[DataLoader], |
| 355 | loss_step: LossCallable, | 355 | loss_step: LossCallable, |
| 356 | no_val: bool = False, | ||
| 357 | sample_frequency: int = 10, | 356 | sample_frequency: int = 10, |
| 358 | checkpoint_frequency: int = 50, | 357 | checkpoint_frequency: int = 50, |
| 359 | global_step_offset: int = 0, | 358 | global_step_offset: int = 0, |
| @@ -472,7 +471,7 @@ def train_loop( | |||
| 472 | 471 | ||
| 473 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | 472 | on_after_epoch(lr_scheduler.get_last_lr()[0]) |
| 474 | 473 | ||
| 475 | if val_dataloader is not None and not no_val: | 474 | if val_dataloader is not None: |
| 476 | model.eval() | 475 | model.eval() |
| 477 | 476 | ||
| 478 | cur_loss_val = AverageMeter() | 477 | cur_loss_val = AverageMeter() |
| @@ -616,8 +615,7 @@ def train( | |||
| 616 | optimizer=optimizer, | 615 | optimizer=optimizer, |
| 617 | lr_scheduler=lr_scheduler, | 616 | lr_scheduler=lr_scheduler, |
| 618 | train_dataloader=train_dataloader, | 617 | train_dataloader=train_dataloader, |
| 619 | val_dataloader=val_dataloader, | 618 | val_dataloader=val_dataloader if not no_val else None, |
| 620 | no_val=no_val, | ||
| 621 | loss_step=loss_step_, | 619 | loss_step=loss_step_, |
| 622 | sample_frequency=sample_frequency, | 620 | sample_frequency=sample_frequency, |
| 623 | checkpoint_frequency=checkpoint_frequency, | 621 | checkpoint_frequency=checkpoint_frequency, |
