summary refs log tree commit diff stats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py6
1 file changed, 2 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py
index b7ea90d..78a2b10 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -353,7 +353,6 @@ def train_loop(
353 train_dataloader: DataLoader, 353 train_dataloader: DataLoader,
354 val_dataloader: Optional[DataLoader], 354 val_dataloader: Optional[DataLoader],
355 loss_step: LossCallable, 355 loss_step: LossCallable,
356 no_val: bool = False,
357 sample_frequency: int = 10, 356 sample_frequency: int = 10,
358 checkpoint_frequency: int = 50, 357 checkpoint_frequency: int = 50,
359 global_step_offset: int = 0, 358 global_step_offset: int = 0,
@@ -472,7 +471,7 @@ def train_loop(
472 471
473 on_after_epoch(lr_scheduler.get_last_lr()[0]) 472 on_after_epoch(lr_scheduler.get_last_lr()[0])
474 473
475 if val_dataloader is not None and not no_val: 474 if val_dataloader is not None:
476 model.eval() 475 model.eval()
477 476
478 cur_loss_val = AverageMeter() 477 cur_loss_val = AverageMeter()
@@ -616,8 +615,7 @@ def train(
616 optimizer=optimizer, 615 optimizer=optimizer,
617 lr_scheduler=lr_scheduler, 616 lr_scheduler=lr_scheduler,
618 train_dataloader=train_dataloader, 617 train_dataloader=train_dataloader,
619 val_dataloader=val_dataloader, 618 val_dataloader=val_dataloader if not no_val else None,
620 no_val=no_val,
621 loss_step=loss_step_, 619 loss_step=loss_step_,
622 sample_frequency=sample_frequency, 620 sample_frequency=sample_frequency,
623 checkpoint_frequency=checkpoint_frequency, 621 checkpoint_frequency=checkpoint_frequency,