summaryrefslogtreecommitdiffstats
path: root/data/prompt.py
blob: 0e66196babb203e7a5f30bf86ce03eb1e02ff7da (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch.utils.data import Dataset


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Item ``i`` is a dict pairing ``prompt_ids[i]`` with ``nprompt_ids[i]``.
    The dataset length is taken from ``prompt_ids``. ``nprompt_ids`` is
    expected to have at least as many entries.

    Args:
        prompt_ids: Per-sample prompt entries, returned under the
            ``"prompt_ids"`` key.
        nprompt_ids: Per-sample entries returned under the ``"nprompt_ids"``
            key. Presumably negative-prompt ids; confirm with the caller.
    """

    def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
        self.prompt_ids = prompt_ids
        self.nprompt_ids = nprompt_ids

    def __len__(self):
        """Return the number of samples, i.e. the number of prompt entries."""
        # Size by the attribute that actually exists. There is no
        # ``self.prompts``; referencing it raised AttributeError on every
        # len() call (e.g. from DataLoader/samplers).
        return len(self.prompt_ids)

    def __getitem__(self, index):
        """Return ``{"prompt_ids": ..., "nprompt_ids": ...}`` for ``index``."""
        return {
            "prompt_ids": self.prompt_ids[index],
            "nprompt_ids": self.nprompt_ids[index],
        }