summary refs log tree commit diff stats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r-- training/util.py 14
1 files changed, 8 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
73 file_path.parent.mkdir(parents=True, exist_ok=True) 73 file_path.parent.mkdir(parents=True, exist_ok=True)
74 74
75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) 75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
76 prompts = [ 76 prompt_ids = [
77 prompt 77 prompt
78 for batch in batches 78 for batch in batches
79 for prompt in batch["prompts"] 79 for prompt in batch["prompt_ids"]
80 ] 80 ]
81 nprompts = [ 81 nprompt_ids = [
82 prompt 82 prompt
83 for batch in batches 83 for batch in batches
84 for prompt in batch["nprompts"] 84 for prompt in batch["nprompt_ids"]
85 ] 85 ]
86 86
87 for i in range(self.sample_batches): 87 for i in range(self.sample_batches):
88 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 88 start = i * self.sample_batch_size
89 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 89 end = (i + 1) * self.sample_batch_size
90 prompt = prompt_ids[start:end]
91 nprompt = nprompt_ids[start:end]
90 92
91 samples = pipeline( 93 samples = pipeline(
92 prompt=prompt, 94 prompt=prompt,