summaryrefslogtreecommitdiffstats
path: root/data/dreambooth/prompt.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-04 19:22:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-04 19:22:22 +0200
commit300deaa789a0321f32d5e7f04d9860eaa258110e (patch)
tree892e89753e5c4d86d787131595751bc03c610be8 /data/dreambooth/prompt.py
parentDefault sample steps 30 -> 40 (diff)
downloadtextual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.gz
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.bz2
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.zip
Add Textual Inversion with class dataset (a la Dreambooth)
Diffstat (limited to 'data/dreambooth/prompt.py')
-rw-r--r--data/dreambooth/prompt.py18
1 files changed, 0 insertions, 18 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
deleted file mode 100644
index b3a83ce..0000000
--- a/data/dreambooth/prompt.py
+++ /dev/null
@@ -1,18 +0,0 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, nprompt, num_samples):
6 self.prompt = prompt
7 self.nprompt = nprompt
8 self.num_samples = num_samples
9
10 def __len__(self):
11 return self.num_samples
12
13 def __getitem__(self, index):
14 example = {}
15 example["prompt"] = self.prompt
16 example["nprompt"] = self.nprompt
17 example["index"] = index
18 return example