from torch.utils.data import Dataset


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Pairs each positive prompt entry with the negative prompt entry at the same
    index and returns them together as a dict.
    """

    def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
        """Store the prompt and negative-prompt sequences.

        Args:
            prompt_ids: Per-example prompt entries, indexed in parallel with
                ``nprompt_ids``.
            nprompt_ids: Per-example negative-prompt entries.

        NOTE(review): the two sequences are presumably expected to have the same
        length; nothing here enforces that. Verify against callers.
        """
        self.prompt_ids = prompt_ids
        self.nprompt_ids = nprompt_ids

    def __len__(self):
        """Return the number of examples, i.e. the number of prompt entries."""
        # Previously referenced a non-existent ``self.prompts`` attribute,
        # which made every len() call raise AttributeError.
        return len(self.prompt_ids)

    def __getitem__(self, index):
        """Return ``{"prompt_ids": ..., "nprompt_ids": ...}`` for ``index``."""
        example = {}
        example["prompt_ids"] = self.prompt_ids[index]
        example["nprompt_ids"] = self.nprompt_ids[index]
        return example