summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/prompt.py
blob: 34f510d20f5bba9dd5d037fa49630f023c8ca0df (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.utils.data import Dataset


class PromptDataset(Dataset):
    """Map-style dataset that yields the same prompt ``num_samples`` times.

    Each item is a dict ``{"prompt": prompt, "index": i}``, where ``i`` is the
    position of the item in ``range(num_samples)``.

    Args:
        prompt: Value returned under the ``"prompt"`` key for every item.
        num_samples: Number of items in the dataset.
    """

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of items, i.e. ``num_samples``."""
        return self.num_samples

    def __getitem__(self, index):
        """Return the item at ``index`` as a ``{"prompt", "index"}`` dict.

        Negative indices count from the end, as with Python sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-num_samples, num_samples)``.
                Raising here is what terminates plain ``for`` iteration over
                the dataset. The class defines no ``__iter__``, so Python falls
                back to calling ``__getitem__`` with 0, 1, 2, ... until
                ``IndexError`` is raised.
        """
        if index < 0:
            index += self.num_samples
        if not 0 <= index < self.num_samples:
            raise IndexError(
                f"index out of range for PromptDataset of length {self.num_samples}"
            )
        return {"prompt": self.prompt, "index": index}