summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 1548784..3d27380 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -87,8 +87,6 @@ def save_samples(
87): 87):
88 print(f"Saving samples for step {step}...") 88 print(f"Saving samples for step {step}...")
89 89
90 samples_path = output_dir.joinpath("samples")
91
92 grid_cols = min(batch_size, 4) 90 grid_cols = min(batch_size, 4)
93 grid_rows = (num_batches * batch_size) // grid_cols 91 grid_rows = (num_batches * batch_size) // grid_cols
94 92
@@ -120,7 +118,7 @@ def save_samples(
120 118
121 for pool, data, gen in datasets: 119 for pool, data, gen in datasets:
122 all_samples = [] 120 all_samples = []
123 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 121 file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
124 file_path.parent.mkdir(parents=True, exist_ok=True) 122 file_path.parent.mkdir(parents=True, exist_ok=True)
125 123
126 batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) 124 batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))