From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 18:59:26 +0100 Subject: More modularization --- training/util.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 0ec2032..cc4cdee 100644 --- a/training/util.py +++ b/training/util.py @@ -41,14 +41,16 @@ class AverageMeter: class CheckpointerBase: def __init__( self, - datamodule, + train_dataloader, + val_dataloader, output_dir: Path, sample_image_size: int, sample_batches: int, sample_batch_size: int, seed: Optional[int] = None ): - self.datamodule = datamodule + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader self.output_dir = output_dir self.sample_image_size = sample_image_size self.seed = seed if seed is not None else torch.random.seed() @@ -70,15 +72,16 @@ class CheckpointerBase: ): samples_path = Path(self.output_dir).joinpath("samples") - train_data = self.datamodule.train_dataloader - val_data = self.datamodule.val_dataloader - generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) grid_cols = min(self.sample_batch_size, 4) grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols - for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: + for pool, data, gen in [ + ("stable", self.val_dataloader, generator), + ("val", self.val_dataloader, None), + ("train", self.train_dataloader, None) + ]: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.jpg") file_path.parent.mkdir(parents=True, exist_ok=True) -- cgit v1.2.3-54-g00ecf