blob: 0e66196babb203e7a5f30bf86ce03eb1e02ff7da (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
from torch.utils.data import Dataset
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Each item ``i`` is a dict pairing ``prompt_ids[i]`` with
    ``nprompt_ids[i]`` under the keys ``"prompt_ids"`` and ``"nprompt_ids"``.
    """

    def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
        """Store the two parallel sequences.

        Args:
            prompt_ids: Per-prompt entries, indexed by position.
                NOTE(review): annotated ``list[int]``, but each element is
                likely a token-id sequence or tensor -- confirm with callers.
            nprompt_ids: Entries paired index-by-index with ``prompt_ids``
                (presumably negative-prompt ids -- confirm). Assumed to be at
                least as long as ``prompt_ids``; this is not checked here.
        """
        self.prompt_ids = prompt_ids
        self.nprompt_ids = nprompt_ids

    def __len__(self):
        """Return the number of prompts in the dataset."""
        # The length is driven by prompt_ids; nprompt_ids is treated as a
        # parallel sequence of the same length.
        return len(self.prompt_ids)

    def __getitem__(self, index):
        """Return the paired entries at ``index`` as a dict.

        Raises:
            IndexError: propagated from the underlying sequences when
                ``index`` is out of range.
        """
        return {
            "prompt_ids": self.prompt_ids[index],
            "nprompt_ids": self.nprompt_ids[index],
        }
|