diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-25 09:11:32 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-25 09:11:32 +0200 |
| commit | 0beac39e60fb4a79edb97a442884684d534722a4 (patch) | |
| tree | 5a5f545155d64906378772d7a5fcbcc6fab2b430 /data | |
| parent | Update (diff) | |
| download | textual-inversion-diff-master.tar.gz textual-inversion-diff-master.tar.bz2 textual-inversion-diff-master.zip | |
Diffstat (limited to 'data')
| -rw-r--r-- | data/prompt.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/data/prompt.py b/data/prompt.py new file mode 100644 index 0000000..0e66196 --- /dev/null +++ b/data/prompt.py | |||
| @@ -0,0 +1,18 @@ | |||
| 1 | from torch.utils.data import Dataset | ||
| 2 | |||
| 3 | |||
| 4 | class PromptDataset(Dataset): | ||
| 5 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." | ||
| 6 | |||
| 7 | def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]): | ||
| 8 | self.prompt_ids = prompt_ids | ||
| 9 | self.nprompt_ids = nprompt_ids | ||
| 10 | |||
| 11 | def __len__(self): | ||
| 12 | return len(self.prompts) | ||
| 13 | |||
| 14 | def __getitem__(self, index): | ||
| 15 | example = {} | ||
| 16 | example["prompt_ids"] = self.prompt_ids[index] | ||
| 17 | example["nprompt_ids"] = self.nprompt_ids[index] | ||
| 18 | return example | ||
