summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/prompt.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-09-27 12:39:43 +0200
committerVolpeon <git@volpeon.ink>2022-09-27 12:39:43 +0200
commit73fe0a75cd08244f91d1baea7b63b42f9e4be08c (patch)
tree9d15c9726ad5fbe528ac40a8b91e9d9c0d3cf6fd /data/dreambooth/prompt.py
parentUndo textual inversion dataset improvements (diff)
downloadtextual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.tar.gz
textual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.tar.bz2
textual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.zip
Added Dreambooth training script
Diffstat (limited to 'data/dreambooth/prompt.py')
-rw-r--r--data/dreambooth/prompt.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
new file mode 100644
index 0000000..34f510d
--- /dev/null
+++ b/data/dreambooth/prompt.py
@@ -0,0 +1,16 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, num_samples):
6 self.prompt = prompt
7 self.num_samples = num_samples
8
9 def __len__(self):
10 return self.num_samples
11
12 def __getitem__(self, index):
13 example = {}
14 example["prompt"] = self.prompt
15 example["index"] = index
16 return example