summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-25 09:11:32 +0200
committerVolpeon <git@volpeon.ink>2023-06-25 09:11:32 +0200
commit0beac39e60fb4a79edb97a442884684d534722a4 (patch)
tree5a5f545155d64906378772d7a5fcbcc6fab2b430 /data
parentUpdate (diff)
downloadtextual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.tar.gz
textual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.tar.bz2
textual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.zip
Diffstat (limited to 'data')
-rw-r--r--data/prompt.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/data/prompt.py b/data/prompt.py
new file mode 100644
index 0000000..0e66196
--- /dev/null
+++ b/data/prompt.py
@@ -0,0 +1,18 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
6
7 def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
8 self.prompt_ids = prompt_ids
9 self.nprompt_ids = nprompt_ids
10
11 def __len__(self):
12 return len(self.prompts)
13
14 def __getitem__(self, index):
15 example = {}
16 example["prompt_ids"] = self.prompt_ids[index]
17 example["nprompt_ids"] = self.nprompt_ids[index]
18 return example