summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-08 13:38:43 +0100
committerVolpeon <git@volpeon.ink>2023-01-08 13:38:43 +0100
commit7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch)
treed275e13506ca737efef18dc6dffa05f4e0d6759f /training
parentImproved aspect ratio bucketing (diff)
downloadtextual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz
textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2
textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'training')
-rw-r--r--training/util.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
73 file_path.parent.mkdir(parents=True, exist_ok=True) 73 file_path.parent.mkdir(parents=True, exist_ok=True)
74 74
75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) 75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
76 prompts = [ 76 prompt_ids = [
77 prompt 77 prompt
78 for batch in batches 78 for batch in batches
79 for prompt in batch["prompts"] 79 for prompt in batch["prompt_ids"]
80 ] 80 ]
81 nprompts = [ 81 nprompt_ids = [
82 prompt 82 prompt
83 for batch in batches 83 for batch in batches
84 for prompt in batch["nprompts"] 84 for prompt in batch["nprompt_ids"]
85 ] 85 ]
86 86
87 for i in range(self.sample_batches): 87 for i in range(self.sample_batches):
88 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 88 start = i * self.sample_batch_size
89 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 89 end = (i + 1) * self.sample_batch_size
90 prompt = prompt_ids[start:end]
91 nprompt = nprompt_ids[start:end]
90 92
91 samples = pipeline( 93 samples = pipeline(
92 prompt=prompt, 94 prompt=prompt,