diff options
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- | data/dreambooth/csv.py | 11 | ||||
-rw-r--r-- | data/dreambooth/prompt.py | 18 |
2 files changed, 4 insertions, 25 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 9075979..abd329d 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
15 | tokenizer, | 15 | tokenizer, |
16 | instance_identifier, | 16 | instance_identifier, |
17 | class_identifier=None, | 17 | class_identifier=None, |
18 | class_subdir="db_cls", | ||
18 | size=512, | 19 | size=512, |
19 | repeats=100, | 20 | repeats=100, |
20 | interpolation="bicubic", | 21 | interpolation="bicubic", |
@@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
30 | raise ValueError("data_file must be a file") | 31 | raise ValueError("data_file must be a file") |
31 | 32 | ||
32 | self.data_root = self.data_file.parent | 33 | self.data_root = self.data_file.parent |
33 | self.class_root = self.data_root.joinpath("db_cls") | 34 | self.class_root = self.data_root.joinpath(class_subdir) |
34 | self.class_root.mkdir(parents=True, exist_ok=True) | 35 | self.class_root.mkdir(parents=True, exist_ok=True) |
35 | 36 | ||
36 | self.tokenizer = tokenizer | 37 | self.tokenizer = tokenizer |
@@ -140,11 +141,9 @@ class CSVDataset(Dataset): | |||
140 | if not instance_image.mode == "RGB": | 141 | if not instance_image.mode == "RGB": |
141 | instance_image = instance_image.convert("RGB") | 142 | instance_image = instance_image.convert("RGB") |
142 | 143 | ||
143 | instance_prompt = prompt.format(self.instance_identifier) | ||
144 | |||
145 | example["instance_images"] = instance_image | 144 | example["instance_images"] = instance_image |
146 | example["instance_prompt_ids"] = self.tokenizer( | 145 | example["instance_prompt_ids"] = self.tokenizer( |
147 | instance_prompt, | 146 | prompt.format(self.instance_identifier), |
148 | padding="do_not_pad", | 147 | padding="do_not_pad", |
149 | truncation=True, | 148 | truncation=True, |
150 | max_length=self.tokenizer.model_max_length, | 149 | max_length=self.tokenizer.model_max_length, |
@@ -155,11 +154,9 @@ class CSVDataset(Dataset): | |||
155 | if not class_image.mode == "RGB": | 154 | if not class_image.mode == "RGB": |
156 | class_image = class_image.convert("RGB") | 155 | class_image = class_image.convert("RGB") |
157 | 156 | ||
158 | class_prompt = prompt.format(self.class_identifier) | ||
159 | |||
160 | example["class_images"] = class_image | 157 | example["class_images"] = class_image |
161 | example["class_prompt_ids"] = self.tokenizer( | 158 | example["class_prompt_ids"] = self.tokenizer( |
162 | class_prompt, | 159 | prompt.format(self.class_identifier), |
163 | padding="do_not_pad", | 160 | padding="do_not_pad", |
164 | truncation=True, | 161 | truncation=True, |
165 | max_length=self.tokenizer.model_max_length, | 162 | max_length=self.tokenizer.model_max_length, |
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
deleted file mode 100644
index b3a83ce..0000000
--- a/data/dreambooth/prompt.py
+++ /dev/null
@@ -1,18 +0,0 @@ | |||
1 | from torch.utils.data import Dataset | ||
2 | |||
3 | |||
4 | class PromptDataset(Dataset): | ||
5 | def __init__(self, prompt, nprompt, num_samples): | ||
6 | self.prompt = prompt | ||
7 | self.nprompt = nprompt | ||
8 | self.num_samples = num_samples | ||
9 | |||
10 | def __len__(self): | ||
11 | return self.num_samples | ||
12 | |||
13 | def __getitem__(self, index): | ||
14 | example = {} | ||
15 | example["prompt"] = self.prompt | ||
16 | example["nprompt"] = self.nprompt | ||
17 | example["index"] = index | ||
18 | return example | ||