diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 09:40:24 +0100 |
| commit | 403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch) | |
| tree | 385c62ef44cc33abc3c5d4b2084c376551137c5f /training/lr.py | |
| parent | Don't use vector_dropout by default (diff) | |
| download | textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2 textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip | |
Fixed reproducibility, more consistant validation
Diffstat (limited to 'training/lr.py')
| -rw-r--r-- | training/lr.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py index 37588b6..a3144ba 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -1,6 +1,6 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import copy | 2 | import copy |
| 3 | from typing import Callable | 3 | from typing import Callable, Any, Tuple, Union |
| 4 | from functools import partial | 4 | from functools import partial |
| 5 | 5 | ||
| 6 | import matplotlib.pyplot as plt | 6 | import matplotlib.pyplot as plt |
| @@ -24,7 +24,7 @@ class LRFinder(): | |||
| 24 | optimizer, | 24 | optimizer, |
| 25 | train_dataloader, | 25 | train_dataloader, |
| 26 | val_dataloader, | 26 | val_dataloader, |
| 27 | loss_fn, | 27 | loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], |
| 28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], None] = noop, |
| 29 | on_eval: Callable[[], None] = noop | 29 | on_eval: Callable[[], None] = noop |
| 30 | ): | 30 | ): |
| @@ -108,7 +108,7 @@ class LRFinder(): | |||
| 108 | if step >= num_val_batches: | 108 | if step >= num_val_batches: |
| 109 | break | 109 | break |
| 110 | 110 | ||
| 111 | loss, acc, bsz = self.loss_fn(batch) | 111 | loss, acc, bsz = self.loss_fn(batch, True) |
| 112 | avg_loss.update(loss.detach_(), bsz) | 112 | avg_loss.update(loss.detach_(), bsz) |
| 113 | avg_acc.update(acc.detach_(), bsz) | 113 | avg_acc.update(acc.detach_(), bsz) |
| 114 | 114 | ||
