diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/dreambooth/csv.py | 11 | ||||
| -rw-r--r-- | data/dreambooth/prompt.py | 18 | 
2 files changed, 4 insertions, 25 deletions
| diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 9075979..abd329d 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 15 | tokenizer, | 15 | tokenizer, | 
| 16 | instance_identifier, | 16 | instance_identifier, | 
| 17 | class_identifier=None, | 17 | class_identifier=None, | 
| 18 | class_subdir="db_cls", | ||
| 18 | size=512, | 19 | size=512, | 
| 19 | repeats=100, | 20 | repeats=100, | 
| 20 | interpolation="bicubic", | 21 | interpolation="bicubic", | 
| @@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 30 | raise ValueError("data_file must be a file") | 31 | raise ValueError("data_file must be a file") | 
| 31 | 32 | ||
| 32 | self.data_root = self.data_file.parent | 33 | self.data_root = self.data_file.parent | 
| 33 | self.class_root = self.data_root.joinpath("db_cls") | 34 | self.class_root = self.data_root.joinpath(class_subdir) | 
| 34 | self.class_root.mkdir(parents=True, exist_ok=True) | 35 | self.class_root.mkdir(parents=True, exist_ok=True) | 
| 35 | 36 | ||
| 36 | self.tokenizer = tokenizer | 37 | self.tokenizer = tokenizer | 
| @@ -140,11 +141,9 @@ class CSVDataset(Dataset): | |||
| 140 | if not instance_image.mode == "RGB": | 141 | if not instance_image.mode == "RGB": | 
| 141 | instance_image = instance_image.convert("RGB") | 142 | instance_image = instance_image.convert("RGB") | 
| 142 | 143 | ||
| 143 | instance_prompt = prompt.format(self.instance_identifier) | ||
| 144 | |||
| 145 | example["instance_images"] = instance_image | 144 | example["instance_images"] = instance_image | 
| 146 | example["instance_prompt_ids"] = self.tokenizer( | 145 | example["instance_prompt_ids"] = self.tokenizer( | 
| 147 | instance_prompt, | 146 | prompt.format(self.instance_identifier), | 
| 148 | padding="do_not_pad", | 147 | padding="do_not_pad", | 
| 149 | truncation=True, | 148 | truncation=True, | 
| 150 | max_length=self.tokenizer.model_max_length, | 149 | max_length=self.tokenizer.model_max_length, | 
| @@ -155,11 +154,9 @@ class CSVDataset(Dataset): | |||
| 155 | if not class_image.mode == "RGB": | 154 | if not class_image.mode == "RGB": | 
| 156 | class_image = class_image.convert("RGB") | 155 | class_image = class_image.convert("RGB") | 
| 157 | 156 | ||
| 158 | class_prompt = prompt.format(self.class_identifier) | ||
| 159 | |||
| 160 | example["class_images"] = class_image | 157 | example["class_images"] = class_image | 
| 161 | example["class_prompt_ids"] = self.tokenizer( | 158 | example["class_prompt_ids"] = self.tokenizer( | 
| 162 | class_prompt, | 159 | prompt.format(self.class_identifier), | 
| 163 | padding="do_not_pad", | 160 | padding="do_not_pad", | 
| 164 | truncation=True, | 161 | truncation=True, | 
| 165 | max_length=self.tokenizer.model_max_length, | 162 | max_length=self.tokenizer.model_max_length, | 
| diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py deleted file mode 100644 index b3a83ce..0000000 --- a/data/dreambooth/prompt.py +++ /dev/null | |||
| @@ -1,18 +0,0 @@ | |||
| 1 | from torch.utils.data import Dataset | ||
| 2 | |||
| 3 | |||
| 4 | class PromptDataset(Dataset): | ||
| 5 | def __init__(self, prompt, nprompt, num_samples): | ||
| 6 | self.prompt = prompt | ||
| 7 | self.nprompt = nprompt | ||
| 8 | self.num_samples = num_samples | ||
| 9 | |||
| 10 | def __len__(self): | ||
| 11 | return self.num_samples | ||
| 12 | |||
| 13 | def __getitem__(self, index): | ||
| 14 | example = {} | ||
| 15 | example["prompt"] = self.prompt | ||
| 16 | example["nprompt"] = self.nprompt | ||
| 17 | example["index"] = index | ||
| 18 | return example | ||
