diff options
author | Volpeon <git@volpeon.ink> | 2022-09-27 12:39:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-27 12:39:43 +0200 |
commit | 73fe0a75cd08244f91d1baea7b63b42f9e4be08c (patch) | |
tree | 9d15c9726ad5fbe528ac40a8b91e9d9c0d3cf6fd /data/dreambooth/prompt.py | |
parent | Undo textual inversion dataset improvements (diff) | |
download | textual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.tar.gz textual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.tar.bz2 textual-inversion-diff-73fe0a75cd08244f91d1baea7b63b42f9e4be08c.zip |
Added Dreambooth training script
Diffstat (limited to 'data/dreambooth/prompt.py')
-rw-r--r-- | data/dreambooth/prompt.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py new file mode 100644 index 0000000..34f510d --- /dev/null +++ b/data/dreambooth/prompt.py | |||
@@ -0,0 +1,16 @@ | |||
1 | from torch.utils.data import Dataset | ||
2 | |||
3 | |||
4 | class PromptDataset(Dataset): | ||
5 | def __init__(self, prompt, num_samples): | ||
6 | self.prompt = prompt | ||
7 | self.num_samples = num_samples | ||
8 | |||
9 | def __len__(self): | ||
10 | return self.num_samples | ||
11 | |||
12 | def __getitem__(self, index): | ||
13 | example = {} | ||
14 | example["prompt"] = self.prompt | ||
15 | example["index"] = index | ||
16 | return example | ||