diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
commit | 46b6c09a18b41edff77c6881529b66733d788abe (patch) | |
tree | 670e7cdda37ba7a010b570398a63dd38e357b6ce /data/dreambooth/prompt.py | |
parent | Small perf improvements (diff) | |
download | textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.gz textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.bz2 textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.zip |
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'data/dreambooth/prompt.py')
-rw-r--r-- | data/dreambooth/prompt.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py | |||
@@ -2,8 +2,9 @@ from torch.utils.data import Dataset | |||
2 | 2 | ||
3 | 3 | ||
4 | class PromptDataset(Dataset): | 4 | class PromptDataset(Dataset): |
5 | def __init__(self, prompt, num_samples): | 5 | def __init__(self, prompt, nprompt, num_samples): |
6 | self.prompt = prompt | 6 | self.prompt = prompt |
7 | self.nprompt = nprompt | ||
7 | self.num_samples = num_samples | 8 | self.num_samples = num_samples |
8 | 9 | ||
9 | def __len__(self): | 10 | def __len__(self): |
@@ -12,5 +13,6 @@ class PromptDataset(Dataset): | |||
12 | def __getitem__(self, index): | 13 | def __getitem__(self, index): |
13 | example = {} | 14 | example = {} |
14 | example["prompt"] = self.prompt | 15 | example["prompt"] = self.prompt |
16 | example["nprompt"] = self.nprompt | ||
15 | example["index"] = index | 17 | example["index"] = index |
16 | return example | 18 | return example |