blob: b3a83ce053485fac90223b72f39222d12accc7c0 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
from torch.utils.data import Dataset
class PromptDataset(Dataset):
    """Map-style dataset that yields the same prompt pair ``num_samples`` times.

    Every item holds the stored ``prompt`` and ``nprompt`` values unchanged,
    together with the integer position it was requested at. ``nprompt`` is
    presumably a negative prompt; confirm against the callers.
    """

    def __init__(self, prompt, nprompt, num_samples: int) -> None:
        """Store the prompt pair and the number of items to expose.

        Args:
            prompt: Value returned under the ``"prompt"`` key of every item.
            nprompt: Value returned under the ``"nprompt"`` key of every item.
            num_samples: Value reported by ``len(dataset)``.
        """
        self.prompt = prompt
        self.nprompt = nprompt
        self.num_samples = num_samples

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, index) -> dict:
        """Build a new dict for *index*.

        The index is not range-checked here. The sampler driving the dataset
        is expected to stay within ``range(len(self))``.
        """
        return {
            "prompt": self.prompt,
            "nprompt": self.nprompt,
            "index": index,
        }
|