diff options
Diffstat (limited to 'data/dreambooth/prompt.py')
| -rw-r--r-- | data/dreambooth/prompt.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py new file mode 100644 index 0000000..34f510d --- /dev/null +++ b/data/dreambooth/prompt.py | |||
| @@ -0,0 +1,16 @@ | |||
| 1 | from torch.utils.data import Dataset | ||
| 2 | |||
| 3 | |||
| 4 | class PromptDataset(Dataset): | ||
| 5 | def __init__(self, prompt, num_samples): | ||
| 6 | self.prompt = prompt | ||
| 7 | self.num_samples = num_samples | ||
| 8 | |||
| 9 | def __len__(self): | ||
| 10 | return self.num_samples | ||
| 11 | |||
| 12 | def __getitem__(self, index): | ||
| 13 | example = {} | ||
| 14 | example["prompt"] = self.prompt | ||
| 15 | example["index"] = index | ||
| 16 | return example | ||
