summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/prompt.py
blob: b3a83ce053485fac90223b72f39222d12accc7c0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch.utils.data import Dataset


class PromptDataset(Dataset):
    """Dataset that yields the same prompt pair ``num_samples`` times.

    Every item is a fresh dict holding the stored ``prompt`` and ``nprompt``
    values together with the requested ``index``. Only the index differs
    between items.

    Args:
        prompt: Value returned under the ``"prompt"`` key of every item
            (presumably a text prompt string -- confirm against callers).
        nprompt: Value returned under the ``"nprompt"`` key of every item
            (presumably a negative prompt -- confirm against callers).
        num_samples: Number of items; this is what ``len()`` reports.
    """

    def __init__(self, prompt, nprompt, num_samples):
        self.prompt = prompt
        self.nprompt = nprompt
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of items, ``num_samples``."""
        return self.num_samples

    def __getitem__(self, index):
        """Return the example dict for ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-num_samples, num_samples)``.
                ``Dataset`` defines no ``__iter__``, so plain iteration
                (``for x in ds``) falls back to calling ``__getitem__`` with
                0, 1, 2, ... until ``IndexError`` is raised. Without this
                check that iteration never terminates.
        """
        if index >= self.num_samples or index < -self.num_samples:
            raise IndexError(
                f"index {index} out of range for dataset of size {self.num_samples}"
            )
        # Build a new dict per call so callers may mutate it safely.
        return {
            "prompt": self.prompt,
            "nprompt": self.nprompt,
            "index": index,
        }