diff options
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/training/util.py b/training/util.py index 6f42228..2b7f71d 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | 2 | import json |
| 3 | import copy | 3 | import copy |
| 4 | import itertools | ||
| 4 | from typing import Iterable, Optional | 5 | from typing import Iterable, Optional |
| 5 | from contextlib import contextmanager | 6 | from contextlib import contextmanager |
| 6 | 7 | ||
| @@ -71,13 +72,7 @@ class CheckpointerBase: | |||
| 71 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 72 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
| 72 | file_path.parent.mkdir(parents=True, exist_ok=True) | 73 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 73 | 74 | ||
| 74 | data_enum = enumerate(data) | 75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) |
| 75 | |||
| 76 | batches = [ | ||
| 77 | batch | ||
| 78 | for j, batch in data_enum | ||
| 79 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
| 80 | ] | ||
| 81 | prompts = [ | 76 | prompts = [ |
| 82 | prompt | 77 | prompt |
| 83 | for batch in batches | 78 | for batch in batches |
