diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/training/util.py b/training/util.py index 6f42228..2b7f71d 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,6 +1,7 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | import itertools | ||
4 | from typing import Iterable, Optional | 5 | from typing import Iterable, Optional |
5 | from contextlib import contextmanager | 6 | from contextlib import contextmanager |
6 | 7 | ||
@@ -71,13 +72,7 @@ class CheckpointerBase: | |||
71 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 72 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
72 | file_path.parent.mkdir(parents=True, exist_ok=True) | 73 | file_path.parent.mkdir(parents=True, exist_ok=True) |
73 | 74 | ||
74 | data_enum = enumerate(data) | 75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) |
75 | |||
76 | batches = [ | ||
77 | batch | ||
78 | for j, batch in data_enum | ||
79 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
80 | ] | ||
81 | prompts = [ | 76 | prompts = [ |
82 | prompt | 77 | prompt |
83 | for batch in batches | 78 | for batch in batches |