From 46b6c09a18b41edff77c6881529b66733d788abe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 21:28:52 +0200 Subject: Dreambooth: Generate specialized class images from input prompts --- data/dreambooth/prompt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'data/dreambooth/prompt.py') diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py @@ -2,8 +2,9 @@ from torch.utils.data import Dataset class PromptDataset(Dataset): - def __init__(self, prompt, num_samples): + def __init__(self, prompt, nprompt, num_samples): self.prompt = prompt + self.nprompt = nprompt self.num_samples = num_samples def __len__(self): @@ -12,5 +13,6 @@ class PromptDataset(Dataset): def __getitem__(self, index): example = {} example["prompt"] = self.prompt + example["nprompt"] = self.nprompt example["index"] = index return example -- cgit v1.2.3-54-g00ecf