diff options
author | Volpeon <git@volpeon.ink> | 2023-02-16 09:20:40 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-16 09:20:40 +0100 |
commit | 44be0ef1f50313b9a4290bb78c45334187d1ab56 (patch) | |
tree | d29bd201008ae9813f52defd0fdf04cb8e945cd4 /training | |
parent | Integrated WIP UniPC scheduler (diff) | |
download | textual-inversion-diff-44be0ef1f50313b9a4290bb78c45334187d1ab56.tar.gz textual-inversion-diff-44be0ef1f50313b9a4290bb78c45334187d1ab56.tar.bz2 textual-inversion-diff-44be0ef1f50313b9a4290bb78c45334187d1ab56.zip |
Fix
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index b7ea90d..78a2b10 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -353,7 +353,6 @@ def train_loop( | |||
353 | train_dataloader: DataLoader, | 353 | train_dataloader: DataLoader, |
354 | val_dataloader: Optional[DataLoader], | 354 | val_dataloader: Optional[DataLoader], |
355 | loss_step: LossCallable, | 355 | loss_step: LossCallable, |
356 | no_val: bool = False, | ||
357 | sample_frequency: int = 10, | 356 | sample_frequency: int = 10, |
358 | checkpoint_frequency: int = 50, | 357 | checkpoint_frequency: int = 50, |
359 | global_step_offset: int = 0, | 358 | global_step_offset: int = 0, |
@@ -472,7 +471,7 @@ def train_loop( | |||
472 | 471 | ||
473 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | 472 | on_after_epoch(lr_scheduler.get_last_lr()[0]) |
474 | 473 | ||
475 | if val_dataloader is not None and not no_val: | 474 | if val_dataloader is not None: |
476 | model.eval() | 475 | model.eval() |
477 | 476 | ||
478 | cur_loss_val = AverageMeter() | 477 | cur_loss_val = AverageMeter() |
@@ -616,8 +615,7 @@ def train( | |||
616 | optimizer=optimizer, | 615 | optimizer=optimizer, |
617 | lr_scheduler=lr_scheduler, | 616 | lr_scheduler=lr_scheduler, |
618 | train_dataloader=train_dataloader, | 617 | train_dataloader=train_dataloader, |
619 | val_dataloader=val_dataloader, | 618 | val_dataloader=val_dataloader if not no_val else None, |
620 | no_val=no_val, | ||
621 | loss_step=loss_step_, | 619 | loss_step=loss_step_, |
622 | sample_frequency=sample_frequency, | 620 | sample_frequency=sample_frequency, |
623 | checkpoint_frequency=checkpoint_frequency, | 621 | checkpoint_frequency=checkpoint_frequency, |