From 73fe0a75cd08244f91d1baea7b63b42f9e4be08c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 12:39:43 +0200 Subject: Added Dreambooth training script --- data/dreambooth/prompt.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 data/dreambooth/prompt.py (limited to 'data/dreambooth/prompt.py') diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py new file mode 100644 index 0000000..34f510d --- /dev/null +++ b/data/dreambooth/prompt.py @@ -0,0 +1,16 @@ +from torch.utils.data import Dataset + + +class PromptDataset(Dataset): + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example -- cgit v1.2.3-54-g00ecf