blob: 34f510d20f5bba9dd5d037fa49630f023c8ca0df (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
from torch.utils.data import Dataset
class PromptDataset(Dataset):
    """Dataset that yields the same prompt ``num_samples`` times.

    Each item is a dict ``{"prompt": prompt, "index": i}``, where ``i`` is the
    (non-negative) position of the item. This is useful for fanning a single
    prompt out over a ``DataLoader`` (e.g. to generate several samples for
    one prompt, possibly across processes).

    Args:
        prompt: The prompt returned for every item (stored as-is).
        num_samples: Number of items in the dataset.
    """

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return the item at ``index``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-num_samples, num_samples)``.
                Raising here is required: ``Dataset`` defines no ``__iter__``,
                so plain iteration (``for ex in ds`` / ``list(ds)``) relies on
                ``__getitem__`` raising ``IndexError`` to stop.
        """
        if index < 0:
            index += self.num_samples
        if not 0 <= index < self.num_samples:
            raise IndexError(
                f"index out of range for PromptDataset of length {self.num_samples}"
            )
        return {"prompt": self.prompt, "index": index}
|