diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 144 | 
1 file changed, 88 insertions, 56 deletions
| diff --git a/data/csv.py b/data/csv.py index 4986153..59d6d8d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor | |||
| 11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 
| 12 | 12 | ||
| 13 | 13 | ||
| 14 | image_cache: dict[str, Image.Image] = {} | ||
| 15 | |||
| 16 | |||
| 17 | def get_image(path): | ||
| 18 | if path in image_cache: | ||
| 19 | return image_cache[path] | ||
| 20 | |||
| 21 | image = Image.open(path) | ||
| 22 | if not image.mode == "RGB": | ||
| 23 | image = image.convert("RGB") | ||
| 24 | image_cache[path] = image | ||
| 25 | |||
| 26 | return image | ||
| 27 | |||
| 28 | |||
| 14 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 29 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 
| 15 | return {"content": prompt} if isinstance(prompt, str) else prompt | 30 | return {"content": prompt} if isinstance(prompt, str) else prompt | 
| 16 | 31 | ||
| 17 | 32 | ||
| 18 | class CSVDataItem(NamedTuple): | 33 | class VlpnDataItem(NamedTuple): | 
| 19 | instance_image_path: Path | 34 | instance_image_path: Path | 
| 20 | class_image_path: Path | 35 | class_image_path: Path | 
| 21 | prompt: list[str] | 36 | prompt: list[str] | 
| @@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple): | |||
| 24 | collection: list[str] | 39 | collection: list[str] | 
| 25 | 40 | ||
| 26 | 41 | ||
| 27 | class CSVDataModule(): | 42 | class VlpnDataBucket(): | 
| 43 | def __init__(self, width: int, height: int): | ||
| 44 | self.width = width | ||
| 45 | self.height = height | ||
| 46 | self.ratio = width / height | ||
| 47 | self.items: list[VlpnDataItem] = [] | ||
| 48 | |||
| 49 | |||
| 50 | class VlpnDataModule(): | ||
| 28 | def __init__( | 51 | def __init__( | 
| 29 | self, | 52 | self, | 
| 30 | batch_size: int, | 53 | batch_size: int, | 
| @@ -36,11 +59,10 @@ class CSVDataModule(): | |||
| 36 | repeats: int = 1, | 59 | repeats: int = 1, | 
| 37 | dropout: float = 0, | 60 | dropout: float = 0, | 
| 38 | interpolation: str = "bicubic", | 61 | interpolation: str = "bicubic", | 
| 39 | center_crop: bool = False, | ||
| 40 | template_key: str = "template", | 62 | template_key: str = "template", | 
| 41 | valid_set_size: Optional[int] = None, | 63 | valid_set_size: Optional[int] = None, | 
| 42 | seed: Optional[int] = None, | 64 | seed: Optional[int] = None, | 
| 43 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 65 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 
| 44 | collate_fn=None, | 66 | collate_fn=None, | 
| 45 | num_workers: int = 0 | 67 | num_workers: int = 0 | 
| 46 | ): | 68 | ): | 
| @@ -60,7 +82,6 @@ class CSVDataModule(): | |||
| 60 | self.size = size | 82 | self.size = size | 
| 61 | self.repeats = repeats | 83 | self.repeats = repeats | 
| 62 | self.dropout = dropout | 84 | self.dropout = dropout | 
| 63 | self.center_crop = center_crop | ||
| 64 | self.template_key = template_key | 85 | self.template_key = template_key | 
| 65 | self.interpolation = interpolation | 86 | self.interpolation = interpolation | 
| 66 | self.valid_set_size = valid_set_size | 87 | self.valid_set_size = valid_set_size | 
| @@ -70,14 +91,14 @@ class CSVDataModule(): | |||
| 70 | self.num_workers = num_workers | 91 | self.num_workers = num_workers | 
| 71 | self.batch_size = batch_size | 92 | self.batch_size = batch_size | 
| 72 | 93 | ||
| 73 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 94 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 
| 74 | image = template["image"] if "image" in template else "{}" | 95 | image = template["image"] if "image" in template else "{}" | 
| 75 | prompt = template["prompt"] if "prompt" in template else "{content}" | 96 | prompt = template["prompt"] if "prompt" in template else "{content}" | 
| 76 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 97 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 
| 77 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 98 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 
| 78 | 99 | ||
| 79 | return [ | 100 | return [ | 
| 80 | CSVDataItem( | 101 | VlpnDataItem( | 
| 81 | self.data_root.joinpath(image.format(item["image"])), | 102 | self.data_root.joinpath(image.format(item["image"])), | 
| 82 | None, | 103 | None, | 
| 83 | prompt_to_keywords( | 104 | prompt_to_keywords( | 
| @@ -97,17 +118,17 @@ class CSVDataModule(): | |||
| 97 | for item in data | 118 | for item in data | 
| 98 | ] | 119 | ] | 
| 99 | 120 | ||
| 100 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | 121 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: | 
| 101 | if self.filter is None: | 122 | if self.filter is None: | 
| 102 | return items | 123 | return items | 
| 103 | 124 | ||
| 104 | return [item for item in items if self.filter(item)] | 125 | return [item for item in items if self.filter(item)] | 
| 105 | 126 | ||
| 106 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 127 | def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: | 
| 107 | image_multiplier = max(num_class_images, 1) | 128 | image_multiplier = max(num_class_images, 1) | 
| 108 | 129 | ||
| 109 | return [ | 130 | return [ | 
| 110 | CSVDataItem( | 131 | VlpnDataItem( | 
| 111 | item.instance_image_path, | 132 | item.instance_image_path, | 
| 112 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 133 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 
| 113 | item.prompt, | 134 | item.prompt, | 
| @@ -119,7 +140,30 @@ class CSVDataModule(): | |||
| 119 | for i in range(image_multiplier) | 140 | for i in range(image_multiplier) | 
| 120 | ] | 141 | ] | 
| 121 | 142 | ||
| 122 | def prepare_data(self): | 143 | def generate_buckets(self, items: list[VlpnDataItem]): | 
| 144 | buckets = [VlpnDataBucket(self.size, self.size)] | ||
| 145 | |||
| 146 | for i in range(1, 5): | ||
| 147 | s = self.size + i * 64 | ||
| 148 | buckets.append(VlpnDataBucket(s, self.size)) | ||
| 149 | buckets.append(VlpnDataBucket(self.size, s)) | ||
| 150 | |||
| 151 | for item in items: | ||
| 152 | image = get_image(item.instance_image_path) | ||
| 153 | ratio = image.width / image.height | ||
| 154 | |||
| 155 | if ratio >= 1: | ||
| 156 | candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] | ||
| 157 | else: | ||
| 158 | candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] | ||
| 159 | |||
| 160 | for bucket in candidates: | ||
| 161 | bucket.items.append(item) | ||
| 162 | |||
| 163 | buckets = [bucket for bucket in buckets if len(bucket.items) != 0] | ||
| 164 | return buckets | ||
| 165 | |||
| 166 | def setup(self): | ||
| 123 | with open(self.data_file, 'rt') as f: | 167 | with open(self.data_file, 'rt') as f: | 
| 124 | metadata = json.load(f) | 168 | metadata = json.load(f) | 
| 125 | template = metadata[self.template_key] if self.template_key in metadata else {} | 169 | template = metadata[self.template_key] if self.template_key in metadata else {} | 
| @@ -144,48 +188,48 @@ class CSVDataModule(): | |||
| 144 | self.data_train = self.pad_items(data_train, self.num_class_images) | 188 | self.data_train = self.pad_items(data_train, self.num_class_images) | 
| 145 | self.data_val = self.pad_items(data_val) | 189 | self.data_val = self.pad_items(data_val) | 
| 146 | 190 | ||
| 147 | def setup(self, stage=None): | 191 | buckets = self.generate_buckets(data_train) | 
| 148 | train_dataset = CSVDataset( | 192 | |
| 149 | self.data_train, self.prompt_processor, batch_size=self.batch_size, | 193 | train_datasets = [ | 
| 150 | num_class_images=self.num_class_images, | 194 | VlpnDataset( | 
| 151 | size=self.size, interpolation=self.interpolation, | 195 | bucket.items, self.prompt_processor, batch_size=self.batch_size, | 
| 152 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout | 196 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, | 
| 153 | ) | 197 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, | 
| 154 | val_dataset = CSVDataset( | 198 | ) | 
| 155 | self.data_val, self.prompt_processor, batch_size=self.batch_size, | 199 | for bucket in buckets | 
| 156 | size=self.size, interpolation=self.interpolation, | 200 | ] | 
| 157 | center_crop=self.center_crop | 201 | |
| 158 | ) | 202 | val_dataset = VlpnDataset( | 
| 159 | self.train_dataloader_ = DataLoader( | 203 | data_val, self.prompt_processor, batch_size=self.batch_size, | 
| 160 | train_dataset, batch_size=self.batch_size, | 204 | width=self.size, height=self.size, interpolation=self.interpolation, | 
| 161 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | ||
| 162 | num_workers=self.num_workers | ||
| 163 | ) | ||
| 164 | self.val_dataloader_ = DataLoader( | ||
| 165 | val_dataset, batch_size=self.batch_size, | ||
| 166 | pin_memory=True, collate_fn=self.collate_fn, | ||
| 167 | num_workers=self.num_workers | ||
| 168 | ) | 205 | ) | 
| 169 | 206 | ||
| 170 | def train_dataloader(self): | 207 | self.train_dataloaders = [ | 
| 171 | return self.train_dataloader_ | 208 | DataLoader( | 
| 209 | dataset, batch_size=self.batch_size, shuffle=True, | ||
| 210 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | ||
| 211 | ) | ||
| 212 | for dataset in train_datasets | ||
| 213 | ] | ||
| 172 | 214 | ||
| 173 | def val_dataloader(self): | 215 | self.val_dataloader = DataLoader( | 
| 174 | return self.val_dataloader_ | 216 | val_dataset, batch_size=self.batch_size, | 
| 217 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | ||
| 218 | ) | ||
| 175 | 219 | ||
| 176 | 220 | ||
| 177 | class CSVDataset(Dataset): | 221 | class VlpnDataset(Dataset): | 
| 178 | def __init__( | 222 | def __init__( | 
| 179 | self, | 223 | self, | 
| 180 | data: List[CSVDataItem], | 224 | data: List[VlpnDataItem], | 
| 181 | prompt_processor: PromptProcessor, | 225 | prompt_processor: PromptProcessor, | 
| 182 | batch_size: int = 1, | 226 | batch_size: int = 1, | 
| 183 | num_class_images: int = 0, | 227 | num_class_images: int = 0, | 
| 184 | size: int = 768, | 228 | width: int = 768, | 
| 229 | height: int = 768, | ||
| 185 | repeats: int = 1, | 230 | repeats: int = 1, | 
| 186 | dropout: float = 0, | 231 | dropout: float = 0, | 
| 187 | interpolation: str = "bicubic", | 232 | interpolation: str = "bicubic", | 
| 188 | center_crop: bool = False, | ||
| 189 | ): | 233 | ): | 
| 190 | 234 | ||
| 191 | self.data = data | 235 | self.data = data | 
| @@ -193,7 +237,6 @@ class CSVDataset(Dataset): | |||
| 193 | self.batch_size = batch_size | 237 | self.batch_size = batch_size | 
| 194 | self.num_class_images = num_class_images | 238 | self.num_class_images = num_class_images | 
| 195 | self.dropout = dropout | 239 | self.dropout = dropout | 
| 196 | self.image_cache = {} | ||
| 197 | 240 | ||
| 198 | self.num_instance_images = len(self.data) | 241 | self.num_instance_images = len(self.data) | 
| 199 | self._length = self.num_instance_images * repeats | 242 | self._length = self.num_instance_images * repeats | 
| @@ -206,8 +249,8 @@ class CSVDataset(Dataset): | |||
| 206 | }[interpolation] | 249 | }[interpolation] | 
| 207 | self.image_transforms = transforms.Compose( | 250 | self.image_transforms = transforms.Compose( | 
| 208 | [ | 251 | [ | 
| 209 | transforms.Resize(size, interpolation=self.interpolation), | 252 | transforms.Resize(min(width, height), interpolation=self.interpolation), | 
| 210 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | 253 | transforms.RandomCrop((height, width)), | 
| 211 | transforms.RandomHorizontalFlip(), | 254 | transforms.RandomHorizontalFlip(), | 
| 212 | transforms.ToTensor(), | 255 | transforms.ToTensor(), | 
| 213 | transforms.Normalize([0.5], [0.5]), | 256 | transforms.Normalize([0.5], [0.5]), | 
| @@ -217,17 +260,6 @@ class CSVDataset(Dataset): | |||
| 217 | def __len__(self): | 260 | def __len__(self): | 
| 218 | return math.ceil(self._length / self.batch_size) * self.batch_size | 261 | return math.ceil(self._length / self.batch_size) * self.batch_size | 
| 219 | 262 | ||
| 220 | def get_image(self, path): | ||
| 221 | if path in self.image_cache: | ||
| 222 | return self.image_cache[path] | ||
| 223 | |||
| 224 | image = Image.open(path) | ||
| 225 | if not image.mode == "RGB": | ||
| 226 | image = image.convert("RGB") | ||
| 227 | self.image_cache[path] = image | ||
| 228 | |||
| 229 | return image | ||
| 230 | |||
| 231 | def get_example(self, i): | 263 | def get_example(self, i): | 
| 232 | item = self.data[i % self.num_instance_images] | 264 | item = self.data[i % self.num_instance_images] | 
| 233 | 265 | ||
| @@ -235,9 +267,9 @@ class CSVDataset(Dataset): | |||
| 235 | example["prompts"] = item.prompt | 267 | example["prompts"] = item.prompt | 
| 236 | example["cprompts"] = item.cprompt | 268 | example["cprompts"] = item.cprompt | 
| 237 | example["nprompts"] = item.nprompt | 269 | example["nprompts"] = item.nprompt | 
| 238 | example["instance_images"] = self.get_image(item.instance_image_path) | 270 | example["instance_images"] = get_image(item.instance_image_path) | 
| 239 | if self.num_class_images != 0: | 271 | if self.num_class_images != 0: | 
| 240 | example["class_images"] = self.get_image(item.class_image_path) | 272 | example["class_images"] = get_image(item.class_image_path) | 
| 241 | 273 | ||
| 242 | return example | 274 | return example | 
| 243 | 275 | ||
