summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py9
1 files changed, 2 insertions, 7 deletions
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4import itertools
4from typing import Iterable, Optional 5from typing import Iterable, Optional
5from contextlib import contextmanager 6from contextlib import contextmanager
6 7
@@ -71,13 +72,7 @@ class CheckpointerBase:
71 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 72 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
72 file_path.parent.mkdir(parents=True, exist_ok=True) 73 file_path.parent.mkdir(parents=True, exist_ok=True)
73 74
74 data_enum = enumerate(data) 75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
75
76 batches = [
77 batch
78 for j, batch in data_enum
79 if j * data.batch_size < self.sample_batch_size * self.sample_batches
80 ]
81 prompts = [ 76 prompts = [
82 prompt 77 prompt
83 for batch in batches 78 for batch in batches