summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/util.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py
index bed7111..bc466e2 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,7 +1,7 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4from typing import Iterable 4from typing import Iterable, Optional
5from contextlib import contextmanager 5from contextlib import contextmanager
6 6
7import torch 7import torch
@@ -42,15 +42,15 @@ class CheckpointerBase:
42 self, 42 self,
43 datamodule, 43 datamodule,
44 output_dir: Path, 44 output_dir: Path,
45 sample_image_size, 45 sample_image_size: int,
46 sample_batches, 46 sample_batches: int,
47 sample_batch_size, 47 sample_batch_size: int,
48 seed 48 seed: Optional[int] = None
49 ): 49 ):
50 self.datamodule = datamodule 50 self.datamodule = datamodule
51 self.output_dir = output_dir 51 self.output_dir = output_dir
52 self.sample_image_size = sample_image_size 52 self.sample_image_size = sample_image_size
53 self.seed = seed or torch.random.seed() 53 self.seed = seed if seed is not None else torch.random.seed()
54 self.sample_batches = sample_batches 54 self.sample_batches = sample_batches
55 self.sample_batch_size = sample_batch_size 55 self.sample_batch_size = sample_batch_size
56 56