summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/prompt.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth/prompt.py')
-rw-r--r--data/dreambooth/prompt.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
new file mode 100644
index 0000000..34f510d
--- /dev/null
+++ b/data/dreambooth/prompt.py
@@ -0,0 +1,16 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, num_samples):
6 self.prompt = prompt
7 self.num_samples = num_samples
8
9 def __len__(self):
10 return self.num_samples
11
12 def __getitem__(self, index):
13 example = {}
14 example["prompt"] = self.prompt
15 example["index"] = index
16 return example