diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 83 |
1 files changed, 37 insertions, 46 deletions
diff --git a/data/csv.py b/data/csv.py index 316c099..4c91ded 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,11 +1,14 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import pandas as pd | 2 | import pandas as pd |
| 3 | import torch | ||
| 3 | from pathlib import Path | 4 | from pathlib import Path |
| 4 | import pytorch_lightning as pl | 5 | import pytorch_lightning as pl |
| 5 | from PIL import Image | 6 | from PIL import Image |
| 6 | from torch.utils.data import Dataset, DataLoader, random_split | 7 | from torch.utils.data import Dataset, DataLoader, random_split |
| 7 | from torchvision import transforms | 8 | from torchvision import transforms |
| 8 | from typing import NamedTuple, List | 9 | from typing import NamedTuple, List, Optional |
| 10 | |||
| 11 | from models.clip.prompt import PromptProcessor | ||
| 9 | 12 | ||
| 10 | 13 | ||
| 11 | class CSVDataItem(NamedTuple): | 14 | class CSVDataItem(NamedTuple): |
| @@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple): | |||
| 18 | class CSVDataModule(pl.LightningDataModule): | 21 | class CSVDataModule(pl.LightningDataModule): |
| 19 | def __init__( | 22 | def __init__( |
| 20 | self, | 23 | self, |
| 21 | batch_size, | 24 | batch_size: int, |
| 22 | data_file, | 25 | data_file: str, |
| 23 | tokenizer, | 26 | prompt_processor: PromptProcessor, |
| 24 | instance_identifier, | 27 | instance_identifier: str, |
| 25 | class_identifier=None, | 28 | class_identifier: Optional[str] = None, |
| 26 | class_subdir="cls", | 29 | class_subdir: str = "cls", |
| 27 | num_class_images=100, | 30 | num_class_images: int = 100, |
| 28 | size=512, | 31 | size: int = 512, |
| 29 | repeats=100, | 32 | repeats: int = 1, |
| 30 | interpolation="bicubic", | 33 | interpolation: str = "bicubic", |
| 31 | center_crop=False, | 34 | center_crop: bool = False, |
| 32 | valid_set_size=None, | 35 | valid_set_size: Optional[int] = None, |
| 33 | generator=None, | 36 | generator: Optional[torch.Generator] = None, |
| 34 | collate_fn=None | 37 | collate_fn=None |
| 35 | ): | 38 | ): |
| 36 | super().__init__() | 39 | super().__init__() |
| @@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 45 | self.class_root.mkdir(parents=True, exist_ok=True) | 48 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 46 | self.num_class_images = num_class_images | 49 | self.num_class_images = num_class_images |
| 47 | 50 | ||
| 48 | self.tokenizer = tokenizer | 51 | self.prompt_processor = prompt_processor |
| 49 | self.instance_identifier = instance_identifier | 52 | self.instance_identifier = instance_identifier |
| 50 | self.class_identifier = class_identifier | 53 | self.class_identifier = class_identifier |
| 51 | self.size = size | 54 | self.size = size |
| @@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 65 | self.data_root.joinpath(item.image), | 68 | self.data_root.joinpath(item.image), |
| 66 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), | 69 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), |
| 67 | item.prompt, | 70 | item.prompt, |
| 68 | item.nprompt if "nprompt" in item else "" | 71 | item.nprompt |
| 69 | ) | 72 | ) |
| 70 | for item in data | 73 | for item in data |
| 71 | for i in range(image_multiplier) | 74 | for i in range(image_multiplier) |
| @@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 88 | self.data_val = self.prepare_subdata(data_val) | 91 | self.data_val = self.prepare_subdata(data_val) |
| 89 | 92 | ||
| 90 | def setup(self, stage=None): | 93 | def setup(self, stage=None): |
| 91 | train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, | 94 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
| 92 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 95 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 93 | num_class_images=self.num_class_images, | 96 | num_class_images=self.num_class_images, |
| 94 | size=self.size, interpolation=self.interpolation, | 97 | size=self.size, interpolation=self.interpolation, |
| 95 | center_crop=self.center_crop, repeats=self.repeats) | 98 | center_crop=self.center_crop, repeats=self.repeats) |
| 96 | val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, | 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 97 | instance_identifier=self.instance_identifier, | 100 | instance_identifier=self.instance_identifier, |
| 98 | size=self.size, interpolation=self.interpolation, | 101 | size=self.size, interpolation=self.interpolation, |
| 99 | center_crop=self.center_crop, repeats=self.repeats) | 102 | center_crop=self.center_crop, repeats=self.repeats) |
| @@ -113,19 +116,19 @@ class CSVDataset(Dataset): | |||
| 113 | def __init__( | 116 | def __init__( |
| 114 | self, | 117 | self, |
| 115 | data: List[CSVDataItem], | 118 | data: List[CSVDataItem], |
| 116 | tokenizer, | 119 | prompt_processor: PromptProcessor, |
| 117 | instance_identifier, | 120 | instance_identifier: str, |
| 118 | batch_size=1, | 121 | batch_size: int = 1, |
| 119 | class_identifier=None, | 122 | class_identifier: Optional[str] = None, |
| 120 | num_class_images=0, | 123 | num_class_images: int = 0, |
| 121 | size=512, | 124 | size: int = 512, |
| 122 | repeats=1, | 125 | repeats: int = 1, |
| 123 | interpolation="bicubic", | 126 | interpolation: str = "bicubic", |
| 124 | center_crop=False, | 127 | center_crop: bool = False, |
| 125 | ): | 128 | ): |
| 126 | 129 | ||
| 127 | self.data = data | 130 | self.data = data |
| 128 | self.tokenizer = tokenizer | 131 | self.prompt_processor = prompt_processor |
| 129 | self.batch_size = batch_size | 132 | self.batch_size = batch_size |
| 130 | self.instance_identifier = instance_identifier | 133 | self.instance_identifier = instance_identifier |
| 131 | self.class_identifier = class_identifier | 134 | self.class_identifier = class_identifier |
| @@ -163,12 +166,6 @@ class CSVDataset(Dataset): | |||
| 163 | 166 | ||
| 164 | example = {} | 167 | example = {} |
| 165 | 168 | ||
| 166 | if isinstance(item.prompt, str): | ||
| 167 | item.prompt = [item.prompt] | ||
| 168 | |||
| 169 | if isinstance(item.nprompt, str): | ||
| 170 | item.nprompt = [item.nprompt] | ||
| 171 | |||
| 172 | example["prompts"] = item.prompt | 169 | example["prompts"] = item.prompt |
| 173 | example["nprompts"] = item.nprompt | 170 | example["nprompts"] = item.nprompt |
| 174 | 171 | ||
| @@ -181,12 +178,9 @@ class CSVDataset(Dataset): | |||
| 181 | self.image_cache[item.instance_image_path] = instance_image | 178 | self.image_cache[item.instance_image_path] = instance_image |
| 182 | 179 | ||
| 183 | example["instance_images"] = instance_image | 180 | example["instance_images"] = instance_image |
| 184 | example["instance_prompt_ids"] = self.tokenizer( | 181 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 185 | item.prompt.format(self.instance_identifier), | 182 | item.prompt.format(self.instance_identifier) |
| 186 | padding="max_length", | 183 | ) |
| 187 | truncation=True, | ||
| 188 | max_length=self.tokenizer.model_max_length, | ||
| 189 | ).input_ids | ||
| 190 | 184 | ||
| 191 | if self.num_class_images != 0: | 185 | if self.num_class_images != 0: |
| 192 | class_image = Image.open(item.class_image_path) | 186 | class_image = Image.open(item.class_image_path) |
| @@ -194,12 +188,9 @@ class CSVDataset(Dataset): | |||
| 194 | class_image = class_image.convert("RGB") | 188 | class_image = class_image.convert("RGB") |
| 195 | 189 | ||
| 196 | example["class_images"] = class_image | 190 | example["class_images"] = class_image |
| 197 | example["class_prompt_ids"] = self.tokenizer( | 191 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 198 | item.prompt.format(self.class_identifier), | 192 | item.nprompt.format(self.class_identifier) |
| 199 | padding="max_length", | 193 | ) |
| 200 | truncation=True, | ||
| 201 | max_length=self.tokenizer.model_max_length, | ||
| 202 | ).input_ids | ||
| 203 | 194 | ||
| 204 | self.cache[item.instance_image_path] = example | 195 | self.cache[item.instance_image_path] = example |
| 205 | return example | 196 | return example |
