From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 13:38:43 +0100 Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline --- training/util.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index ae6bfc4..60d64f0 100644 --- a/training/util.py +++ b/training/util.py @@ -73,20 +73,22 @@ class CheckpointerBase: file_path.parent.mkdir(parents=True, exist_ok=True) batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) - prompts = [ + prompt_ids = [ prompt for batch in batches - for prompt in batch["prompts"] + for prompt in batch["prompt_ids"] ] - nprompts = [ + nprompt_ids = [ prompt for batch in batches - for prompt in batch["nprompts"] + for prompt in batch["nprompt_ids"] ] for i in range(self.sample_batches): - prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] - nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + start = i * self.sample_batch_size + end = (i + 1) * self.sample_batch_size + prompt = prompt_ids[start:end] + nprompt = nprompt_ids[start:end] samples = pipeline( prompt=prompt, -- cgit v1.2.3-54-g00ecf