From 0beac39e60fb4a79edb97a442884684d534722a4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Jun 2023 09:11:32 +0200 Subject: Update --- data/prompt.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 data/prompt.py (limited to 'data') diff --git a/data/prompt.py b/data/prompt.py new file mode 100644 index 0000000..0e66196 --- /dev/null +++ b/data/prompt.py @@ -0,0 +1,18 @@ +from torch.utils.data import Dataset + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]): + self.prompt_ids = prompt_ids + self.nprompt_ids = nprompt_ids + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + example = {} + example["prompt_ids"] = self.prompt_ids[index] + example["nprompt_ids"] = self.nprompt_ids[index] + return example -- cgit v1.2.3-70-g09d2