diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
commit | 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch) | |
tree | d275e13506ca737efef18dc6dffa05f4e0d6759f /training | |
parent | Improved aspect ratio bucketing (diff) | |
download | textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2 textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip |
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py index ae6bfc4..60d64f0 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -73,20 +73,22 @@ class CheckpointerBase: | |||
73 | file_path.parent.mkdir(parents=True, exist_ok=True) | 73 | file_path.parent.mkdir(parents=True, exist_ok=True) |
74 | 74 | ||
75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) | 75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) |
76 | prompts = [ | 76 | prompt_ids = [ |
77 | prompt | 77 | prompt |
78 | for batch in batches | 78 | for batch in batches |
79 | for prompt in batch["prompts"] | 79 | for prompt in batch["prompt_ids"] |
80 | ] | 80 | ] |
81 | nprompts = [ | 81 | nprompt_ids = [ |
82 | prompt | 82 | prompt |
83 | for batch in batches | 83 | for batch in batches |
84 | for prompt in batch["nprompts"] | 84 | for prompt in batch["nprompt_ids"] |
85 | ] | 85 | ] |
86 | 86 | ||
87 | for i in range(self.sample_batches): | 87 | for i in range(self.sample_batches): |
88 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 88 | start = i * self.sample_batch_size |
89 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 89 | end = (i + 1) * self.sample_batch_size |
90 | prompt = prompt_ids[start:end] | ||
91 | nprompt = nprompt_ids[start:end] | ||
90 | 92 | ||
91 | samples = pipeline( | 93 | samples = pipeline( |
92 | prompt=prompt, | 94 | prompt=prompt, |