summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/util.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py
index a623dc5..a80e44f 100644
--- a/training/util.py
+++ b/training/util.py
@@ -75,7 +75,7 @@ class CheckpointerBase:
75 ) 75 )
76 76
77 grid_cols = max(self.sample_batch_size, 4) 77 grid_cols = max(self.sample_batch_size, 4)
78 grid_rows = self.sample_batches * self.sample_batch_size / grid_cols 78 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
79 79
80 with torch.autocast("cuda"), torch.inference_mode(): 80 with torch.autocast("cuda"), torch.inference_mode():
81 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 81 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: