summaryrefslogtreecommitdiffstats
path: root/data/prompt.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/prompt.py')
-rw-r--r--data/prompt.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/data/prompt.py b/data/prompt.py
new file mode 100644
index 0000000..0e66196
--- /dev/null
+++ b/data/prompt.py
@@ -0,0 +1,18 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
6
7 def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
8 self.prompt_ids = prompt_ids
9 self.nprompt_ids = nprompt_ids
10
11 def __len__(self):
12 return len(self.prompts)
13
14 def __getitem__(self, index):
15 example = {}
16 example["prompt_ids"] = self.prompt_ids[index]
17 example["nprompt_ids"] = self.nprompt_ids[index]
18 return example