diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py index bed7111..bc466e2 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,7 +1,7 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable | 4 | from typing import Iterable, Optional |
5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
6 | 6 | ||
7 | import torch | 7 | import torch |
@@ -42,15 +42,15 @@ class CheckpointerBase: | |||
42 | self, | 42 | self, |
43 | datamodule, | 43 | datamodule, |
44 | output_dir: Path, | 44 | output_dir: Path, |
45 | sample_image_size, | 45 | sample_image_size: int, |
46 | sample_batches, | 46 | sample_batches: int, |
47 | sample_batch_size, | 47 | sample_batch_size: int, |
48 | seed | 48 | seed: Optional[int] = None |
49 | ): | 49 | ): |
50 | self.datamodule = datamodule | 50 | self.datamodule = datamodule |
51 | self.output_dir = output_dir | 51 | self.output_dir = output_dir |
52 | self.sample_image_size = sample_image_size | 52 | self.sample_image_size = sample_image_size |
53 | self.seed = seed or torch.random.seed() | 53 | self.seed = seed if seed is not None else torch.random.seed() |
54 | self.sample_batches = sample_batches | 54 | self.sample_batches = sample_batches |
55 | self.sample_batch_size = sample_batch_size | 55 | self.sample_batch_size = sample_batch_size |
56 | 56 | ||