diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
| commit | 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch) | |
| tree | d275e13506ca737efef18dc6dffa05f4e0d6759f /training/util.py | |
| parent | Improved aspect ratio bucketing (diff) | |
| download | textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2 textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip | |
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py index ae6bfc4..60d64f0 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -73,20 +73,22 @@ class CheckpointerBase: | |||
| 73 | file_path.parent.mkdir(parents=True, exist_ok=True) | 73 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 74 | 74 | ||
| 75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) | 75 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) |
| 76 | prompts = [ | 76 | prompt_ids = [ |
| 77 | prompt | 77 | prompt |
| 78 | for batch in batches | 78 | for batch in batches |
| 79 | for prompt in batch["prompts"] | 79 | for prompt in batch["prompt_ids"] |
| 80 | ] | 80 | ] |
| 81 | nprompts = [ | 81 | nprompt_ids = [ |
| 82 | prompt | 82 | prompt |
| 83 | for batch in batches | 83 | for batch in batches |
| 84 | for prompt in batch["nprompts"] | 84 | for prompt in batch["nprompt_ids"] |
| 85 | ] | 85 | ] |
| 86 | 86 | ||
| 87 | for i in range(self.sample_batches): | 87 | for i in range(self.sample_batches): |
| 88 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 88 | start = i * self.sample_batch_size |
| 89 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 89 | end = (i + 1) * self.sample_batch_size |
| 90 | prompt = prompt_ids[start:end] | ||
| 91 | nprompt = nprompt_ids[start:end] | ||
| 90 | 92 | ||
| 91 | samples = pipeline( | 93 | samples = pipeline( |
| 92 | prompt=prompt, | 94 | prompt=prompt, |
