From b8df3dd5330845ff9f9f9af187a09ef0dbfc1c20 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 17:34:23 +0100 Subject: Update --- training/util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'training') diff --git a/training/util.py b/training/util.py index bed7111..bc466e2 100644 --- a/training/util.py +++ b/training/util.py @@ -1,7 +1,7 @@ from pathlib import Path import json import copy -from typing import Iterable +from typing import Iterable, Optional from contextlib import contextmanager import torch @@ -42,15 +42,15 @@ class CheckpointerBase: self, datamodule, output_dir: Path, - sample_image_size, - sample_batches, - sample_batch_size, - seed + sample_image_size: int, + sample_batches: int, + sample_batch_size: int, + seed: Optional[int] = None ): self.datamodule = datamodule self.output_dir = output_dir self.sample_image_size = sample_image_size - self.seed = seed or torch.random.seed() + self.seed = seed if seed is not None else torch.random.seed() self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size -- cgit v1.2.3-54-g00ecf